diff --git a/sample_solutions/HybridSearch/.env.example b/sample_solutions/HybridSearch/.env.example new file mode 100644 index 00000000..7c0120f5 --- /dev/null +++ b/sample_solutions/HybridSearch/.env.example @@ -0,0 +1,87 @@ +DEPLOYMENT_PHASE=production +SYSTEM_MODE=document + +# Local URL Endpoint (only needed for non-public domains) +# If using a local domain like api.example.com mapped to localhost: +# Set this to: api.example.com (domain without https://) +# If using a public domain, set any placeholder value like: not-needed +LOCAL_URL_ENDPOINT=not-needed + +# Service Ports +GATEWAY_PORT=8000 +EMBEDDING_PORT=8001 +RETRIEVAL_PORT=8002 +LLM_PORT=8003 +INGESTION_PORT=8004 +UI_PORT=8501 + +# Inference Gateway Configuration +# GENAI_GATEWAY_URL: Base URL to your inference gateway (without /v1 suffix) +# - For GenAI Gateway: https://genai-gateway.example.com +# - For APISIX Gateway: https://apisix-gateway.example.com +GENAI_GATEWAY_URL=https://api.example.com + +# GENAI_API_KEY: Authentication token/API key for the inference gateway +# - For GenAI Gateway: Your GenAI Gateway API key (litellm_master_key from vault.yml) +# To generate: use the generate-vault-secrets.sh script +# - For APISIX Gateway: Your APISIX authentication token +# To generate: use the generate-token.sh script (Keycloak client credentials) +GENAI_API_KEY=your-pre-generated-token-here + +# Model Configuration +# IMPORTANT: MODEL_ENDPOINT is the route/model identifier used by your gateway. +# - Xeon + Keycloak/APISIX: use the APISIX route name (e.g. bge-base-en-v1.5-vllmcpu) +# Run: kubectl get apisixroutes to find the exact route names for your deployment +# - Xeon + GenAI Gateway: use the HuggingFace model ID (e.g. BAAI/bge-base-en-v1.5) +# - Gaudi (TEI): use the HuggingFace model ID (e.g. BAAI/bge-base-en-v1.5) +# MODEL_NAME is always the HuggingFace model ID used in the API request payload. +# Check available models: curl https://your-gateway-url/v1/models -H "Authorization: Bearer your-token" +EMBEDDING_MODEL_ENDPOINT=BAAI/bge-base-en-v1.5 +EMBEDDING_MODEL_NAME=BAAI/bge-base-en-v1.5 +RERANKER_MODEL_ENDPOINT=BAAI/bge-reranker-base +RERANKER_MODEL_NAME=BAAI/bge-reranker-base +LLM_MODEL_ENDPOINT=Qwen/Qwen3-4B-Instruct-2507 +LLM_MODEL_NAME=Qwen/Qwen3-4B-Instruct-2507 + +# Inference Backend Type +# Set to "tei" for Gaudi hardware (TEI serves at /embeddings and /rerank — no /v1 prefix) +# Set to "vllm" for Xeon hardware (vLLM serves at /v1/embeddings and /v1/rerank) +INFERENCE_BACKEND=vllm + +# APISIX Gateway Per-Model Endpoints (required for Keycloak / APISIX deployments) +# Set these to the full APISIX route URL for each model. +# The route name matches the APISIX route (kubectl get apisixroutes). +# Xeon default route names use the -vllmcpu suffix. +# EMBEDDING_API_ENDPOINT=https://api.example.com/bge-base-en-v1.5-vllmcpu +# RERANKER_API_ENDPOINT=https://api.example.com/bge-reranker-base-vllmcpu +# LLM_API_ENDPOINT=https://api.example.com/Qwen3-4B-Instruct-2507-vllmcpu + +# Retrieval Configuration +USE_RERANKING=true +RERANKER_MAX_BATCH_SIZE=32 # Max docs per rerank request — reduce if your model has a lower limit +TOP_K_DENSE=100 +TOP_K_SPARSE=100 +TOP_K_FUSION=50 +TOP_K_RERANK=10 +RRF_K=60 + +# Embedding/Ingestion Configuration +EMBEDDING_BATCH_SIZE=32 # reduce for larger documents; must match embedding service batch size + +# Ingestion Configuration +CHUNK_SIZE=512 +CHUNK_OVERLAP=50 +MAX_FILE_SIZE_MB=100 +SUPPORTED_FORMATS=pdf,docx,xlsx,ppt,txt + +# UI Configuration +UI_TITLE=InsightMapper Lite +UI_PAGE_ICON= +UI_LAYOUT=wide + +# Logging +LOG_LEVEL=INFO + +# SSL Verification Settings +# Set to false only for dev with self-signed certs +VERIFY_SSL=true diff --git a/sample_solutions/HybridSearch/.gitattributes b/sample_solutions/HybridSearch/.gitattributes new file mode 100644 index 00000000..90b9bdf3 --- /dev/null +++ b/sample_solutions/HybridSearch/.gitattributes @@ -0,0 +1,15 @@ +# Git attributes for hybrid-search project +*.py text eol=lf +*.md text eol=lf +*.txt text eol=lf +*.yml text eol=lf +*.yaml text eol=lf +*.json text eol=lf +*.toml text eol=lf +*.sh text eol=lf +Dockerfile text eol=lf +.env* text eol=lf +*.pkl binary +*.bin binary +*.db binary +*.pdf binary diff --git a/sample_solutions/HybridSearch/.gitignore b/sample_solutions/HybridSearch/.gitignore new file mode 100644 index 00000000..1692772f --- /dev/null +++ b/sample_solutions/HybridSearch/.gitignore @@ -0,0 +1,99 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +*.pyc + +# Virtual Environment +venv/ +ENV/ +env/ +.venv + +# Environment variables +.env +.env.local +.env.*.local +*.bak +.env.bak + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Data directories +data/documents/* +data/indexes/* +data/*.db +data/*.sqlite +!data/documents/.gitkeep +!data/indexes/.gitkeep + +# Ingestion service data (documents, indexes, metadata) +api/ingestion/data/documents/* +api/ingestion/data/indexes/* +api/ingestion/data/*.db +api/ingestion/data/*.sqlite +!api/ingestion/data/documents/.gitkeep +!api/ingestion/data/indexes/.gitkeep + +# Model weights +models/ +*.bin +*.pt +*.pth +*.onnx +*.safetensors + +# Logs +logs/ +*.log + +# Jupyter Notebook +.ipynb_checkpoints + +# Docker +*.tar +.dockerignore + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.hypothesis/ + +# Monitoring +monitoring/data/ + +# Temporary files +tmp/ +temp/ +*.tmp + +# API Keys (extra safety) +**/api_key* +**/secret* +**/*secret* diff --git a/sample_solutions/HybridSearch/README.md b/sample_solutions/HybridSearch/README.md new file mode 100644 index 00000000..4489c180 --- /dev/null +++ b/sample_solutions/HybridSearch/README.md @@ -0,0 +1,387 @@ +# Hybrid Search RAG System + +A microservices-based RAG (Retrieval-Augmented Generation) system supporting both document Q&A and semantic product catalog search with hybrid retrieval (dense + sparse) with reranking. + +## Table of Contents + +- [Project Overview](#project-overview) +- [Features](#features) +- [Architecture](#architecture) +- [Prerequisites](#prerequisites) +- [Deployment & Configuration](#deployment--configuration) +- [Reranker Post-Deployment Configuration](#reranker-post-deployment-configuration) +- [User Interface](#user-interface) +- [Troubleshooting](#troubleshooting) +- [Additional Information](#additional-information) + +## Project Overview + +Hybrid Search RAG System is a microservices-based application that enables intelligent search and question-answering over both unstructured documents and structured product catalogs. The system combines dense vector search (FAISS) with sparse keyword search (BM25) using Reciprocal Rank Fusion (RRF) to deliver high-accuracy results. It seamlessly switches between document analysis and product discovery modes, leveraging powerful enterprise language models for generation. + +## Features + +- **Document RAG**: Upload PDFs, DOCX, XLSX, PPT and ask questions with citations +- **Product Catalog Search**: Upload product catalogs (CSV/JSON) and search with natural language +- **Hybrid Retrieval**: Combines FAISS (dense) and BM25 (sparse) search with RRF fusion +- **Dual Mode**: Switch between document and product modes seamlessly +- **Modern React web interface**: Streamlit-based interface with product grid and chat interface +- **RESTful API**: For integration with JSON-based communication + +## Architecture + +This application uses a microservices architecture where each service handles a specific part of the search and generation process. The Streamlit frontend communicates with a backend gateway that orchestrates requests across specialized services: embedding generation, retrieval, LLM processing, and data ingestion. + +```mermaid +graph TD + subgraph Client + A[Streamlit UI] + end + + subgraph Gateway + B[Gateway
Port 8000] + end + + subgraph Services + C[Embedding] + D[Retrieval
FAISS + BM25 + RRF] + E[LLM] + F[Ingestion] + end + + subgraph "Storage (/data volume)" + DOCS[(Documents
raw files)] + INDEX[(Unified Indexes
FAISS · BM25 · metadata.pkl
content_type = document or product)] + DB[(SQLite metadata.db)] + end + + %% Essential flows + A -->|query / search| B + A -->|upload| B + B -->|forward upload| F + B -->|embed + search| C + C --> D + B -->|hybrid search + filters| D + B -->|generate optional| E + F -->|chunk & embed| C + F -->|update indexes| INDEX + F -->|store files| DOCS + F -->|store metadata| DB + D -->|read indexes| INDEX + D -->|read metadata| DB + D --> B + E --> B + B -->|results / answer| A + + %% Notes + NOTE1[Ingests documents & product catalogs
sets content_type accordingly] + NOTE2[Filters by content_type
price/rating/category] + + F -.-> NOTE1 + D -.-> NOTE2 + + %% Styling + classDef client fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px + classDef gateway fill:#fff8e1,stroke:#ff8f00,stroke-width:3px + classDef service fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + classDef storage fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef note fill:#fff,stroke:#666,stroke-width:1px,stroke-dasharray: 5 5 + + class A client + class B gateway + class C,D,E,F service + class DOCS,INDEX,DB storage + class NOTE1,NOTE2 note +``` + +``` +UI (8501) → Gateway (8000) → [Embedding (8001), Retrieval (8002), LLM (8003), Ingestion (8004)] +``` + +**Service Components:** + +- **UI (Port 8501)** - Streamlit web interface for file uploads, chat, and product browsing +- **Gateway (Port 8000)** - API orchestration and routing +- **Embedding Service (Port 8001)** - Vector generation using enterprise embedding models +- **Retrieval Service (Port 8002)** - Hybrid search implementation (FAISS + BM25 + RRF) +- **LLM Service (Port 8003)** - Question answering using enterprise LLMs +- **Ingestion Service (Port 8004)** - Document/product processing and indexing + +## Prerequisites + +### System Requirements +Before you begin, ensure you have the following installed: + +- Docker and Docker Compose +- Python 3.10+ (optional, for local development) + +### Required Models +The following models must be deployed on Enterprise Inference before running this application: + +| Model | Purpose | +|-------|---------| +| `BAAI/bge-base-en-v1.5` | Embedding generation for dense retrieval | +| `BAAI/bge-reranker-base` | Reranking retrieved results for improved relevance | +| `Qwen/Qwen3-4B-Instruct-2507` | LLM for question answering and generation | + +Verify your models are available: +```bash +curl https://your-gateway-url/v1/models -H "Authorization: Bearer your-token" +``` + +### Verify Docker Installation +```bash +# Check Docker version +docker --version + +# Check Docker Compose version +docker compose version + +# Verify Docker is running +docker ps +``` + +### Credentials & Authentication +The system uses GenAI Gateway for authentication with API key-based access. + +## Deployment & Configuration + +### Clone the Repository +```bash +git clone https://github.com/opea-project/Enterprise-Inference.git +cd Enterprise-Inference/sample_solutions/HybridSearch +``` + +### Set up the Environment +Create a `.env` file from the example and configure your credentials. + +```bash +cp .env.example .env +``` + +**Note:** The `LOCAL_URL_ENDPOINT` variable enables Docker's `extra_hosts` configuration for local domain resolution. Set it to `not-needed` if you don't require custom domain mapping. + +### Configure Authentication +Configure GenAI Gateway authentication in your `.env` file: + +```bash +# GenAI Gateway Configuration +GENAI_GATEWAY_URL=https://api.example.com +GENAI_API_KEY=your-api-key-here + +# SSL Verification (set to false only for dev with self-signed certs) +VERIFY_SSL=true +``` + +**Generating API tokens by gateway type:** + +- **GenAI Gateway**: Provide your GenAI Gateway URL and API key + - To generate the GenAI Gateway API key, use the [generate-vault-secrets.sh](https://github.com/opea-project/Enterprise-Inference/blob/main/core/scripts/generate-vault-secrets.sh) script + - The API key is the `litellm_master_key` value from the generated `vault.yml` file + +- **Keycloak / APISIX Gateway**: Provide your APISIX Gateway URL and authentication token + - To generate the APISIX authentication token, use the [generate-token.sh](https://github.com/opea-project/Enterprise-Inference/blob/main/core/scripts/generate-token.sh) script + - The token is generated using Keycloak client credentials (expires in 15 minutes by default; a Keycloak admin can configure longer-lived tokens in the Keycloak console) + - For Keycloak, each model has its own APISIX route path. Run `kubectl get apisixroutes` to find the route names for your deployed models (e.g., `bge-base-en-v1.5`, `bge-reranker-base`, `Qwen3-4B-Instruct-2507`) + +### Configure Models +Model endpoint names differ by deployment type. Use the table below to determine the correct values: + +| Variable | Xeon + Keycloak/APISIX | Xeon + GenAI Gateway | Gaudi (TEI) | +|---|---|---|---| +| `EMBEDDING_MODEL_ENDPOINT` | `bge-base-en-v1.5-vllmcpu` | `BAAI/bge-base-en-v1.5` | `BAAI/bge-base-en-v1.5` | +| `EMBEDDING_MODEL_NAME` | `BAAI/bge-base-en-v1.5` | `BAAI/bge-base-en-v1.5` | `BAAI/bge-base-en-v1.5` | +| `RERANKER_MODEL_ENDPOINT` | `bge-reranker-base-vllmcpu` | `BAAI/bge-reranker-base` | `BAAI/bge-reranker-base` | +| `RERANKER_MODEL_NAME` | `BAAI/bge-reranker-base` | `BAAI/bge-reranker-base` | `BAAI/bge-reranker-base` | +| `LLM_MODEL_ENDPOINT` | `Qwen3-4B-Instruct-2507-vllmcpu` | `Qwen/Qwen3-4B-Instruct-2507` | `Qwen/Qwen3-4B-Instruct-2507` | +| `LLM_MODEL_NAME` | `Qwen/Qwen3-4B-Instruct-2507` | `Qwen/Qwen3-4B-Instruct-2507` | `Qwen/Qwen3-4B-Instruct-2507` | + +> `MODEL_ENDPOINT` is the route/model identifier sent to your gateway. For Keycloak/APISIX it is the APISIX route name (run `kubectl get apisixroutes` to verify the exact names for your deployment). `MODEL_NAME` is always the HuggingFace model ID used in the API request payload. + +**Gaudi hardware (TEI backend):** Set `INFERENCE_BACKEND=tei` in your `.env`. TEI serves endpoints without the `/v1` prefix (`/embeddings`, `/rerank`) unlike vLLM which uses `/v1`. Xeon deployments use the default `INFERENCE_BACKEND=vllm`. + +```bash +# Gaudi hardware only +INFERENCE_BACKEND=tei +``` + +**Keycloak / APISIX deployments:** Uncomment and set the per-model API endpoint variables in your `.env`. Each model needs its own APISIX route URL. Xeon route names use the `-vllmcpu` suffix by default: + +```bash +# APISIX Gateway Per-Model Endpoints (required for Keycloak) +EMBEDDING_API_ENDPOINT=https://api.example.com/bge-base-en-v1.5-vllmcpu +RERANKER_API_ENDPOINT=https://api.example.com/bge-reranker-base-vllmcpu +LLM_API_ENDPOINT=https://api.example.com/Qwen3-4B-Instruct-2507-vllmcpu +``` + +### Running the Application +Start all services together with Docker Compose: + +```bash +# Start services in detached mode +docker compose up -d --build +``` + +This will: +- Build all microservices +- Create containers and internal networking +- Start services in detached mode + +### Check all containers are running: + +```bash +docker compose ps +``` + +Expected output shows services with status "Up". + +### View logs: + +```bash +# All services +docker compose logs -f + +# Individual services +docker compose logs -f ui +docker compose logs -f gateway +docker compose logs -f embedding +docker compose logs -f retrieval +docker compose logs -f llm +docker compose logs -f ingestion +``` + +Check each service started correctly: + +```bash +docker compose logs ui | grep -E "Running|Error|startup" +docker compose logs gateway | grep -E "Running|Error|startup" +docker compose logs embedding | grep -E "Running|Error|startup" +docker compose logs retrieval | grep -E "Running|Error|startup" +docker compose logs llm | grep -E "Running|Error|startup" +docker compose logs ingestion | grep -E "Running|Error|startup" +``` + +### Verify the services are running: + +```bash +# Check API health +curl http://localhost:8000/api/v1/health/services +``` + +## Reranker Post-Deployment Configuration + +> [!IMPORTANT] +> **GenAI Gateway + Xeon deployments only.** If you deployed Enterprise Inference with GenAI Gateway on Xeon hardware and have enabled reranking (`USE_RERANKING=true`), the `BAAI/bge-reranker-base` model requires a one-time post-deployment configuration step before it will work correctly. +> +> The deployment script registers the model with the wrong provider (`openai`) and without the required `mode: rerank` field. Without this fix, all rerank requests will return a `400 BadRequestError` from LiteLLM and reranking will silently fall back or fail. + +> [!NOTE] +> The following deployments do **not** require this step — the reranker works out of the box: +> - **Keycloak / APISIX** (Xeon or Gaudi): Set `RERANKER_API_ENDPOINT` in your `.env` to your APISIX route URL and ensure `USE_RERANKING=true` +> - **GenAI Gateway + Gaudi**: The reranker is pre-validated and works without LiteLLM reconfiguration + +The fix involves a single `curl` command to update the model registration in LiteLLM, changing the provider to `cohere` and setting `mode: rerank`. The full step-by-step workflow — including how to find the model UUID, the exact update payload, and how to verify the changes in the LiteLLM UI — is documented in: + +**[reranker-configuration.md](./reranker-configuration.md)** + +**Summary of what the configuration fixes:** + +| Field | Default (broken) | Required (correct) | +|---|---|---| +| LiteLLM provider | `openai` | `cohere` | +| Model mode | *(missing)* | `rerank` | +| Pass-through | `false` | `true` | +| API base | *(may include `/v1` suffix)* | `...vllm-service.default` (no `/v1`) | + +--- + +## User Interface + +### Using the Application + +Access the application at: http://localhost:8501 + +### Test the Application + +#### Document Mode + +![RAG Chatbot Interface](ui/public/rag_chatbot.png) + +1. Switch to "Documents" mode in the sidebar. +2. Upload PDF/DOCX/XLSX/PPT files. +3. Ask questions about the uploaded documents in the chat interface. +4. View answers with citations. + +**Document Q&A with Citations:** + +![Citations Example](ui/public/citations.png) + +The system provides answers with source references, showing which document chunks were used to generate the response. + +#### Product Catalog Mode + +![Product Catalog Search](ui/public/product_catalog.png) + +1. Switch to "Products" mode in the sidebar. +2. Upload a CSV/JSON product catalog. +3. Browse products in the grid view. +4. Search with natural language (e.g., "toys under $20" or "electronics with 4+ stars"). + +### Key Features + +The application provides a clean interface to toggle between modes: + +- **Document Mode**: Chat-based Q&A with document uploads and citation tracking +- **Product Mode**: Visual product grid with semantic search and filters +- **Seamless Switching**: Toggle between modes without losing context +- **Real-time Processing**: See indexing progress and status updates + +## Cleanup + +Stop all services: + +```bash +docker compose down +``` + +Remove all containers and volumes: + +```bash +docker compose down -v +``` + +## Troubleshooting + +**Services won't start:** +```bash +docker compose logs -f [service-name] +docker compose restart [service-name] +``` + +**Connection errors:** +- Verify all services are running: `docker compose ps` +- Check service health: `curl http://localhost:8000/api/v1/health/services` + +**Authentication errors:** +- Verify `GENAI_GATEWAY_URL` and `GENAI_API_KEY` in `.env` +- Ensure GenAI Gateway URL is correct and accessible + +**SSL certificate errors:** +- In production, keep `VERIFY_SSL=true` (default) +- For development environments with self-signed certificates, set `VERIFY_SSL=false` in `.env` + +**Index not loading:** +- Ensure `data/indexes/` directory exists and has write permissions +- Restart retrieval service after re-indexing + +## Additional Information + +The table provides a throughput comparison of LLM performance on **Intel Xeon** and **Intel Gaudi** machines, showing tokens processed, processing time, and effective tokens per second. + +| Model Name | Deployment Platform| Tokens Processed | Processing Time (sec)| Tokens/sec | Completion Status | +|------------------------------------|--------------------|------------------|----------------------|------------|----------------------------------| +| Qwen/Qwen3-8B | Xeon | 126 | 28.2 | 4.5 | Truncated at 50 tokens | +| Qwen/Qwen3-8B | Gaudi | 126 | 4.3 | 29.3 | Truncated at 50 tokens | +| ibm-granite/granite-3.3-8b-instruct| Xeon | 130 | 31 | 4.2 | Truncated at 50 tokens | +| ibm-granite/granite-3.3-8b-instruct| Gaudi | 130 | 4.5 | 28.8 | Truncated at 50 tokens | + +*All completions were cut off at 50 tokens due to the max_tokens setting. Tokens/sec is calculated as Tokens Processed divided by Processing Time (sec).* diff --git a/sample_solutions/HybridSearch/api/embedding/Dockerfile b/sample_solutions/HybridSearch/api/embedding/Dockerfile new file mode 100644 index 00000000..b6c05cad --- /dev/null +++ b/sample_solutions/HybridSearch/api/embedding/Dockerfile @@ -0,0 +1,33 @@ +# Embedding Service Dockerfile +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code and create non-root user +COPY . . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app +USER appuser + +# Expose port +EXPOSE 8001 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8001/health || exit 1 + +# Run the application +CMD ["python", "main.py"] + diff --git a/sample_solutions/HybridSearch/api/embedding/api_client.py b/sample_solutions/HybridSearch/api/embedding/api_client.py new file mode 100644 index 00000000..295a840b --- /dev/null +++ b/sample_solutions/HybridSearch/api/embedding/api_client.py @@ -0,0 +1,135 @@ +""" +API Client for GenAI Gateway authentication and enterprise API calls +""" + +import httpx +import logging +import re +from openai import OpenAI +from config import settings + +logger = logging.getLogger(__name__) + + +def clean_url(url: str) -> str: + """ + Remove invisible characters and whitespace from URL. + + Args: + url (str): The URL string to clean. + + Returns: + str: The cleaned URL string. + """ + if not url: + return url + # Remove non-printable characters, whitespace, and specific zero-width chars + return re.sub(r'[\x00-\x1f\x7f-\x9f\s\u200b\u2060\ufeff]+', '', url) + + +class APIClient: + """ + Client for handling GenAI Gateway authentication and API calls. + + This client manages API calls to GenAI Gateway or APISIX Gateway endpoints, + including embedding generation. + """ + + def __init__(self): + # Use per-model endpoint if set (APISIX), otherwise fall back to GenAI Gateway URL + base_url = settings.embedding_api_endpoint or settings.genai_gateway_url + self.base_url = clean_url(base_url).rstrip('/') if base_url else None + self.token = settings.genai_api_key + # TEI (Gaudi) does not use /v1 prefix; vLLM (Xeon) does + self.use_tei = settings.inference_backend.lower() == "tei" + self.http_client = httpx.Client(verify=settings.verify_ssl, timeout=30.0) if self.token else None + + if not self.token or not self.base_url: + raise ValueError("GenAI Gateway configuration missing. Check GENAI_GATEWAY_URL and GENAI_API_KEY.") + + logger.info(f"Using gateway at {self.base_url} (backend: {settings.inference_backend})") + + def get_embedding_client(self): + """ + Get OpenAI-style client for embeddings. + + Returns: + OpenAI: An instantiated OpenAI client configured for the GenAI Gateway. + """ + # TEI (Gaudi) serves at /embeddings; vLLM (Xeon) serves at /v1/embeddings + client_base_url = self.base_url if self.use_tei else f"{self.base_url}/v1" + logger.info(f"Creating OpenAI client with base_url: {client_base_url}") + + http_client = httpx.Client(verify=settings.verify_ssl, timeout=30.0) + + return OpenAI( + api_key=self.token, + base_url=client_base_url, + http_client=http_client + ) + + def generate_embeddings(self, texts: list[str], model: str) -> dict: + """ + Generate embeddings using raw HTTP request. + + Args: + texts (list[str]): List of texts to generate embeddings for. + model (str): Name of the model to use. + + Returns: + dict: The JSON response from the embedding API. + + Raises: + httpx.HTTPStatusError: If the API request fails. + """ + url = f"{self.base_url}/v1/embeddings" + + payload = { + "input": texts, + "model": model + } + + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + logger.info(f"Sending embedding request to {url}") + logger.info(f"Payload input length: {len(texts)}") + if len(texts) > 0: + logger.info(f"First text length: {len(texts[0])}") + logger.info(f"First text preview: {texts[0][:100]}...") + + response = self.http_client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + logger.error(f"Embedding API error: {response.status_code} - {response.text}") + response.raise_for_status() + + return response.json() + + def is_authenticated(self) -> bool: + """ + Check if client is authenticated. + + Returns: + bool: True if authenticated, False otherwise. + """ + return bool(self.token and self.http_client) + + +# Global instance +_api_client = None + + +def get_api_client(): + """ + Get or create global API client instance. + + Returns: + APIClient: The global singleton instance of APIClient. + """ + global _api_client + if _api_client is None: + _api_client = APIClient() + return _api_client diff --git a/sample_solutions/HybridSearch/api/embedding/config.py b/sample_solutions/HybridSearch/api/embedding/config.py new file mode 100644 index 00000000..5ba1087c --- /dev/null +++ b/sample_solutions/HybridSearch/api/embedding/config.py @@ -0,0 +1,90 @@ +""" +Embedding Service Configuration +Manages environment variables and service settings +Supports GenAI Gateway and APISIX Gateway +""" + +from pydantic_settings import BaseSettings +from typing import Optional +from pathlib import Path + + +class Settings(BaseSettings): + """ + Service configuration with environment variable loading. + + This class defines the configuration settings for the Embedding Service, + including deployment phase and GenAI Gateway/APISIX Gateway settings. + """ + + # Deployment Phase + deployment_phase: str = "development" + + # GenAI Gateway Configuration + # Supports multiple deployment patterns: + # - GenAI Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + # - APISIX Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + genai_gateway_url: Optional[str] = None + genai_api_key: Optional[str] = None + + # Per-model endpoint URL (required for APISIX, optional for GenAI Gateway) + # APISIX: Set to full URL with model path, e.g. https://api.example.com/bge-base-en-v1.5 + # GenAI Gateway: Leave unset (falls back to GENAI_GATEWAY_URL) + embedding_api_endpoint: Optional[str] = None + + # Inference backend type: "vllm" (Xeon, default) or "tei" (Gaudi) + # TEI does not use the /v1 path prefix; vLLM does + inference_backend: str = "vllm" + + # Model Configuration (for Enterprise) + embedding_model_endpoint: str = "bge-large-en-v1.5-vllmcpu" + embedding_model_name: str = "BAAI/bge-large-en-v1.5" + + # Service Configuration + embedding_port: int = 8001 + embedding_host: str = "0.0.0.0" # nosec B104 - Binding to all interfaces is intentional for Docker container + embedding_batch_size: int = 32 + embedding_max_length: int = 512 + + # SSL Verification Settings + verify_ssl: bool = True + + # Logging + log_level: str = "INFO" + + class Config: + # Look for .env file in the hybrid-search root directory + env_file = Path(__file__).parent.parent.parent / ".env" + case_sensitive = False + extra = "ignore" # Ignore extra fields in .env file that aren't defined in this model + + def is_enterprise_configured(self) -> bool: + """ + Check if GenAI Gateway is configured. + + Returns: + bool: True if genai_gateway_url and genai_api_key are present. + """ + return bool(self.genai_gateway_url and self.genai_api_key) + + def validate_config(self): + """ + Validate that GenAI Gateway is configured. + + This service requires GenAI Gateway or APISIX Gateway authentication. + + Raises: + ValueError: If required configuration is missing. + """ + if not self.is_enterprise_configured(): + raise ValueError( + "GenAI Gateway configuration missing. " + "Must provide GENAI_GATEWAY_URL and GENAI_API_KEY in .env file." + ) + + +# Global settings instance +settings = Settings() + +# Validate configuration on import +settings.validate_config() diff --git a/sample_solutions/HybridSearch/api/embedding/main.py b/sample_solutions/HybridSearch/api/embedding/main.py new file mode 100644 index 00000000..49efdd27 --- /dev/null +++ b/sample_solutions/HybridSearch/api/embedding/main.py @@ -0,0 +1,381 @@ +""" +Embedding Service - OpenAI API Wrapper +Generates vector embeddings for documents and queries +""" + +import logging +import time +import math +from typing import List, Optional +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from openai import OpenAI, OpenAIError, RateLimitError, APIConnectionError, APITimeoutError +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log +) +from config import settings + +# Configure logging +logging.basicConfig( + level=settings.log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title="Embedding Service", + description="OpenAI-powered embedding generation service", + version="1.0.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize GenAI Gateway API client +try: + from api_client import get_api_client + + api_client = get_api_client() + + if not api_client.is_authenticated(): + raise RuntimeError("GenAI Gateway authentication failed - cannot start service without API access") + + client = api_client.get_embedding_client() + logger.info("✓ GenAI Gateway API client initialized successfully") + logger.info(f" Model: {settings.embedding_model_name}") + logger.info(f" Authentication: GenAI Gateway API Key") + logger.info(f" Base URL: {settings.genai_gateway_url}") + +except Exception as e: + logger.error(f"Failed to initialize GenAI Gateway API client: {e}") + logger.error("Service requires GenAI Gateway authentication and endpoints") + raise RuntimeError(f"GenAI Gateway API initialization failed: {e}") from e + +# Request/Response Models +class EmbeddingRequest(BaseModel): + """ + Request model for embedding generation. + + Attributes: + texts (List[str]): List of input text strings to embed. + normalize (bool): Whether to apply L2 normalization to the embeddings. + """ + texts: List[str] = Field(..., description="List of texts to embed", min_length=1) + normalize: bool = Field(True, description="Whether to L2 normalize embeddings") + + class Config: + json_schema_extra = { + "example": { + "texts": ["What is artificial intelligence?", "Machine learning basics"], + "normalize": True + } + } + + +class EmbeddingResponse(BaseModel): + """ + Response model for embedding generation. + + Attributes: + embeddings (List[List[float]]): The generated vector embeddings. + model (str): The name of the model used. + dimensions (int): The dimension size of the embeddings. + processing_time_ms (float): Time taken to generate embeddings in milliseconds. + text_count (int): The number of texts processed. + """ + embeddings: List[List[float]] = Field(..., description="Generated embeddings") + model: str = Field(..., description="Model used for embedding") + dimensions: int = Field(..., description="Embedding dimensions") + processing_time_ms: float = Field(..., description="Processing time in milliseconds") + text_count: int = Field(..., description="Number of texts processed") + + +class HealthResponse(BaseModel): + """Health check response""" + status: str + service: str + deployment_phase: str + model: str + dimensions: int + + +class ModelInfoResponse(BaseModel): + """Model information response""" + model: str + dimensions: int + max_input_length: int + batch_size: int + + +# Retry Configuration for OpenAI API +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True +) +def _call_embeddings_api( + texts: List[str], + model: str, +): + """ + Call embeddings API with retry logic + + Args: + texts: List of texts to embed + model: Model name + + Returns: + Embeddings response + + Raises: + OpenAIError: If all retries fail + """ + try: + return client.embeddings.create( + model=model, + input=texts, + ) + except (RateLimitError, APIConnectionError, APITimeoutError) as e: + logger.warning(f"API error (will retry): {type(e).__name__}: {e}") + raise + except OpenAIError as e: + # Don't retry on other errors (invalid request, etc.) + logger.error(f"API error (non-retryable): {e}") + raise + + +# API Endpoints +@app.post( + "/api/v1/embeddings/encode", + response_model=EmbeddingResponse, + status_code=status.HTTP_200_OK, + summary="Generate embeddings for texts", + description="Generate vector embeddings for one or more texts using OpenAI API" +) +async def encode_embeddings(request: EmbeddingRequest): + """ + Generate embeddings for the provided texts. + + Args: + request (EmbeddingRequest): The request containing texts to embed. + + Returns: + EmbeddingResponse: Object containing generated embeddings and metadata. + + Raises: + HTTPException: If input validation fails or external API errors occur. + """ + try: + start_time = time.time() + + # Validate batch size + if len(request.texts) > settings.embedding_batch_size: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Batch size exceeds maximum of {settings.embedding_batch_size}" + ) + + # Validate text lengths + for idx, text in enumerate(request.texts): + if not text.strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Text at index {idx} is empty" + ) + + logger.info(f"Generating embeddings for {len(request.texts)} texts") + + # Use GenAI Gateway via OpenAI client + model_name = settings.embedding_model_name + + # Call API with retry logic + response = _call_embeddings_api( + texts=request.texts, + model=model_name, + ) + + # Extract embeddings + raw_embeddings = [item.embedding for item in response.data] + dimensions = len(raw_embeddings[0]) if raw_embeddings else 768 + + # Sanitize embeddings + embeddings = [] + for embedding in raw_embeddings: + sanitized_embedding = [] + for val in embedding: + if math.isnan(val) or math.isinf(val): + sanitized_embedding.append(0.0) + else: + sanitized_embedding.append(val) + embeddings.append(sanitized_embedding) + + # Calculate processing time + processing_time = (time.time() - start_time) * 1000 + + logger.info( + f"Successfully generated {len(embeddings)} embeddings " + f"in {processing_time:.2f}ms (dimensions={dimensions})" + ) + + return EmbeddingResponse( + embeddings=embeddings, + model=model_name, + dimensions=dimensions, + processing_time_ms=round(processing_time, 2), + text_count=len(request.texts) + ) + + except OpenAIError as e: + logger.error(f"OpenAI API error: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=f"OpenAI API error: {str(e)}" + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {str(e)}" + ) + + +@app.post( + "/api/v1/embeddings/encode-batch", + response_model=EmbeddingResponse, + status_code=status.HTTP_200_OK, + summary="Generate embeddings in batch", + description="Alias for /encode endpoint for batch processing" +) +async def encode_batch(request: EmbeddingRequest): + """ + Generate embeddings for multiple texts (alias for encode endpoint). + + This is functionally identical to /encode but provides a clearer + endpoint name for explicit batch operations. + + Args: + request (EmbeddingRequest): The request containing texts to embed. + + Returns: + EmbeddingResponse: Object containing generated embeddings and metadata. + """ + return await encode_embeddings(request) + + +@app.get( + "/health", + response_model=HealthResponse, + status_code=status.HTTP_200_OK, + summary="Health check", + description="Check if the embedding service is healthy" +) +async def health_check(): + """ + Health check endpoint. + + Returns: + HealthResponse: Status of the service and configuration details. + """ + if settings.is_enterprise_configured(): + model_name = settings.embedding_model_name + dimensions = 768 # BGE default + else: + model_name = "N/A" + dimensions = 0 + + return HealthResponse( + status="healthy", + service="embedding", + deployment_phase=settings.deployment_phase, + model=model_name, + dimensions=dimensions + ) + + +@app.get( + "/api/v1/embeddings/model-info", + response_model=ModelInfoResponse, + status_code=status.HTTP_200_OK, + summary="Get model information", + description="Get information about the embedding model being used" +) +async def get_model_info(): + """ + Get embedding model information. + + Returns: + ModelInfoResponse: Details about the model, dimensions, and limits. + """ + if settings.is_enterprise_configured(): + model_name = settings.embedding_model_name + dimensions = 768 # BGE default + else: + model_name = "N/A" + dimensions = 0 + + return ModelInfoResponse( + model=model_name, + dimensions=dimensions, + max_input_length=settings.embedding_max_length, + batch_size=settings.embedding_batch_size + ) + + +@app.get( + "/", + summary="Root endpoint", + description="Service information" +) +async def root(): + """ + Root endpoint with service information. + + Returns: + dict: Basic service info including version and status. + """ + return { + "service": "Embedding Service", + "version": "1.0.0", + "status": "running", + "docs": "/docs", + "health": "/health" + } + + +# Application startup +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting Embedding Service on {settings.embedding_host}:{settings.embedding_port}") + logger.info(f"Deployment phase: {settings.deployment_phase}") + + if settings.is_enterprise_configured(): + logger.info("Provider: GenAI Gateway") + logger.info(f"Model: {settings.embedding_model_name}") + logger.info("Dimensions: 768 (BGE default)") + else: + logger.warning("Provider: Not configured (GenAI Gateway required)") + + uvicorn.run( + app, + host=settings.embedding_host, # nosec B104 - Binding to all interfaces is intentional for Docker container + port=settings.embedding_port, + log_level=settings.log_level.lower() + ) + diff --git a/sample_solutions/HybridSearch/api/embedding/requirements.txt b/sample_solutions/HybridSearch/api/embedding/requirements.txt new file mode 100644 index 00000000..897b6f6e --- /dev/null +++ b/sample_solutions/HybridSearch/api/embedding/requirements.txt @@ -0,0 +1,30 @@ +# Embedding Service Requirements +# API Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 + +# OpenAI API +openai>=1.35.0 +httpx>=0.25.0 # Required for OpenAI client +requests>=2.32.0 # Security updates + +# Numpy for array operations +numpy==1.24.3 + +# Logging +python-json-logger==2.0.7 + +# Retry Logic +tenacity==8.2.3 + +# Environment +python-dotenv==1.0.0 + +# Production Phase Dependencies (will be used later) +# sentence-transformers==2.2.2 +# torch==2.1.0 +# transformers==4.35.0 +# intel-extension-for-pytorch==2.1.0 + diff --git a/sample_solutions/HybridSearch/api/gateway/Dockerfile b/sample_solutions/HybridSearch/api/gateway/Dockerfile new file mode 100644 index 00000000..bb5a31ce --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/Dockerfile @@ -0,0 +1,33 @@ +# Gateway Service Dockerfile +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code and create non-root user +COPY . . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app +USER appuser + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the application +CMD ["python", "main.py"] + diff --git a/sample_solutions/HybridSearch/api/gateway/config.py b/sample_solutions/HybridSearch/api/gateway/config.py new file mode 100644 index 00000000..9afd54f3 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/config.py @@ -0,0 +1,58 @@ +""" +Gateway Service Configuration +Manages environment variables and service settings +""" + +from pydantic_settings import BaseSettings +from typing import Optional +from pathlib import Path + + +class Settings(BaseSettings): + """ + Service configuration with environment variable loading. + + This class defines configuration for the Gateway Service, including: + - Service host and port + - URLs for downstream services (embedding, retrieval, llm, ingestion) + - Product catalog specific settings + - Logging configuration + """ + + # Deployment Phase + deployment_phase: str = "development" + + # Service Configuration + gateway_port: int = 8000 + gateway_host: str = "0.0.0.0" # nosec B104 - Binding to all interfaces is intentional for Docker container + + # Service URLs + embedding_service_url: str = "http://localhost:8001" + retrieval_service_url: str = "http://localhost:8002" + llm_service_url: str = "http://localhost:8003" + ingestion_service_url: str = "http://localhost:8004" + + # Product Catalog Settings + system_mode: str = "document" # "document" or "product" + default_result_limit: int = 20 + + # Keycloak/Auth Configuration (optional) + base_url: Optional[str] = None + keycloak_realm: Optional[str] = None + + # SSL Verification Settings + verify_ssl: bool = True + + # Logging + log_level: str = "INFO" + + class Config: + # Look for .env file in the hybrid-search root directory + env_file = Path(__file__).parent.parent.parent / ".env" + case_sensitive = False + extra = "ignore" # Ignore extra fields in .env file + + +# Global settings instance +settings = Settings() + diff --git a/sample_solutions/HybridSearch/api/gateway/main.py b/sample_solutions/HybridSearch/api/gateway/main.py new file mode 100644 index 00000000..cbd0b9ed --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/main.py @@ -0,0 +1,556 @@ +""" +Gateway Service +Main API orchestrator for the hybrid search system +""" + +import logging +import time +import os +from typing import Optional, Dict, Any, List +from fastapi import FastAPI, HTTPException, status, Request, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, field_validator +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from config import settings +from services.complexity_detector import ComplexityDetector +from services.orchestrator import ServiceOrchestrator +from services.query_analyzer import QueryAnalyzer +from services.filter_extractor import FilterExtractor +from services.auth import get_current_user + +# Configure logging +logging.basicConfig( + level=settings.log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title="Gateway Service", + description="Main API gateway for hybrid search RAG system", + version="1.0.0" +) + +# Initialize rate limiter +limiter = Limiter(key_func=get_remote_address) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# Configure CORS (environment-based) +cors_origins = os.getenv("CORS_ORIGINS", "*").split(",") if os.getenv("CORS_ORIGINS") != "*" else ["*"] +app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize components +complexity_detector = ComplexityDetector() +orchestrator = ServiceOrchestrator( + retrieval_service_url=settings.retrieval_service_url, + llm_service_url=settings.llm_service_url, + embedding_service_url=getattr(settings, 'embedding_service_url', None), + ingestion_service_url=getattr(settings, 'ingestion_service_url', None) +) +query_analyzer = QueryAnalyzer() +filter_extractor = FilterExtractor( + llm_service_url=settings.llm_service_url +) + + +# Request/Response Models +MAX_QUERY_LENGTH = int(os.getenv("MAX_QUERY_LENGTH", "5000")) + +class QueryRequest(BaseModel): + """ + Request model for main query endpoint. + + Attributes: + query (str): The user's natural language query. + top_k_results (int): Number of context chunks to retrieve (1-50). + force_model (Optional[str]): Force a specific processing strategy ('simple', 'complex'). + include_debug_info (bool): Whether to include detailed execution metadata in response. + """ + query: str = Field(..., description="User query", min_length=1, max_length=MAX_QUERY_LENGTH) + top_k_results: int = Field(10, description="Number of results to retrieve", ge=1, le=50) + force_model: Optional[str] = Field( + None, + description="Force specific model: 'simple', 'complex', or None for auto" + ) + include_debug_info: bool = Field(False, description="Include debug information") + + @field_validator('query') + @classmethod + def validate_query_length(cls, v: str) -> str: + """Validate query length""" + if len(v) > MAX_QUERY_LENGTH: + raise ValueError(f"Query exceeds maximum length of {MAX_QUERY_LENGTH} characters") + if not v.strip(): + raise ValueError("Query cannot be empty or whitespace only") + return v.strip() + + class Config: + json_schema_extra = { + "example": { + "query": "What are the main differences between Product A and Product B?", + "top_k_results": 10, + "force_model": None, + "include_debug_info": False + } + } + + +class QueryResponse(BaseModel): + """ + Response model for query endpoint. + + Attributes: + answer (str): The generated answer to the query. + citations (list): List of sources used to generate the answer. + confidence_score (float): Confidence score of the answer (0.0-1.0). + query_complexity (str): Detected or forced complexity level ('simple', 'complex'). + llm_model (str): Name of the LLM model used for generation. + retrieval_results_count (int): Number of results found in the retrieval step. + processing_time_ms (float): Total time taken to process the query. + debug_info (Optional[Dict]): Detailed execution metadata if requested. + """ + answer: str + citations: list + confidence_score: float = 0.0 + query_complexity: str + llm_model: str + retrieval_results_count: int + processing_time_ms: float + debug_info: Optional[Dict] = None + + +class ProductSearchRequest(BaseModel): + """ + Request model for product search. + + Attributes: + query (str): Natural language search query. + filters (Optional[Dict]): Explicit filters to apply (e.g., price, category). + limit (int): Maximum number of results to return (1-100). + offset (int): Pagination offset. + explain (bool): Whether to generate an LLM explanation of the results. + """ + query: str = Field(..., description="Search query", min_length=1) + filters: Optional[Dict] = Field(None, description="Additional filters") + limit: int = Field(20, description="Number of results", ge=1, le=100) + offset: int = Field(0, description="Offset for pagination", ge=0) + explain: bool = Field(True, description="Whether to generate LLM explanation") + + +class ProductSearchResponse(BaseModel): + """ + Response model for product search. + + Attributes: + query_interpretation (Dict): Analysis of the user's query (intent, entities). + applied_filters (Dict): The final filters applied to the search. + total_matches (int): Total number of matching products found. + results (List[Dict]): List of product results. + explanation (Optional[str]): Natural language explanation of the results. + suggested_refinements (List[str]): Suggestions to narrow down search results. + """ + query_interpretation: Dict + applied_filters: Dict + total_matches: int + results: List[Dict] + explanation: Optional[str] = None + suggested_refinements: List[str] = Field(default_factory=list) + + +class HealthResponse(BaseModel): + """ + Health check response. + + Attributes: + status (str): Status of the service ('healthy', etc.). + service (str): Name of the service. + deployment_phase (str): Current deployment environment. + """ + status: str + service: str + deployment_phase: str + + +class ServiceHealthResponse(BaseModel): + """ + Combined health check for all services. + + Attributes: + gateway (str): Status of the gateway service. + embedding (Dict): Status of the embedding service. + retrieval (Dict): Status of the retrieval service. + llm (Dict): Status of the LLM service. + ingestion (Dict): Status of the ingestion service. + """ + gateway: str + embedding: Dict + retrieval: Dict + llm: Dict + ingestion: Dict + + +# Global Exception Handler +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + """ + Handle all unhandled exceptions globally. + + Args: + request (Request): The incoming request that caused the error. + exc (Exception): The unhandled exception. + + Returns: + JSONResponse: A 500 Internal Server Error response with error details. + """ + logger.error(f"Unhandled exception: {exc}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "error": "Internal server error", + "detail": "An unexpected error occurred. Please try again later.", + "type": type(exc).__name__ + } + ) + + +# API Endpoints +@app.post( + "/api/v1/query", + response_model=QueryResponse, + status_code=status.HTTP_200_OK, + summary="Main query endpoint", + description="Process query through full RAG pipeline" +) +@limiter.limit("100/minute") +async def process_query(request: Request, query_data: QueryRequest, user: dict = Depends(get_current_user)): + """ + Process query through the complete RAG pipeline: + 1. Detect query complexity + 2. Retrieve relevant context + 3. Generate answer with appropriate LLM + 4. Return answer with citations + + Args: + request: FastAPI Request object (for rate limiting) + query_data: QueryRequest with query and parameters + + Returns: + QueryResponse with answer and metadata + """ + try: + start_time = time.time() + debug_info = {} + + logger.info(f"Processing query: {query_data.query[:100]}") + + # 1. Detect query complexity + complexity_result = complexity_detector.detect(query_data.query) + query_complexity = query_data.force_model or complexity_result["complexity"] + + if query_data.include_debug_info: + debug_info["complexity_detection"] = complexity_result + + logger.info(f"Query complexity: {query_complexity} ({complexity_result['reasoning']})") + + # 2. Retrieve relevant context + retrieval_start = time.time() + retrieval_response = await orchestrator.retrieve_context( + query_data.query, + top_k=query_data.top_k_results + ) + retrieval_time = (time.time() - retrieval_start) * 1000 + + results = retrieval_response.get("results", []) + + if not results: + logger.warning("No results found in retrieval") + return QueryResponse( + answer="I don't have enough information to answer this question.", + citations=[], + query_complexity=query_complexity, + llm_model="none", + retrieval_results_count=0, + processing_time_ms=round((time.time() - start_time) * 1000, 2), + debug_info=debug_info if query_data.include_debug_info else None + ) + + if query_data.include_debug_info: + debug_info["retrieval"] = { + "results_count": len(results), + "retrieval_time_ms": round(retrieval_time, 2), + "top_scores": [r.get("score", 0) for r in results[:3]] + } + + # 3. Generate answer + llm_start = time.time() + llm_response = await orchestrator.generate_answer( + query_data.query, + results, + model_type=query_complexity + ) + llm_time = (time.time() - llm_start) * 1000 + + if query_data.include_debug_info: + debug_info["llm"] = { + "model_used": llm_response.get("model_used"), + "generation_time_ms": llm_response.get("generation_time_ms"), + "token_count": llm_response.get("token_count") + } + + # Calculate total processing time + total_time = (time.time() - start_time) * 1000 + + logger.info( + f"Query completed in {total_time:.2f}ms " + f"(retrieval: {retrieval_time:.2f}ms, llm: {llm_time:.2f}ms)" + ) + + return QueryResponse( + answer=llm_response.get("answer", ""), + citations=llm_response.get("citations", []), + query_complexity=llm_response.get("query_type", query_complexity), + llm_model=llm_response.get("model_used", ""), + retrieval_results_count=len(results), + processing_time_ms=round(total_time, 2), + debug_info=debug_info if query_data.include_debug_info else None + ) + + except Exception as e: + logger.error(f"Error processing query: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Query processing failed: {str(e)}" + ) + + +@app.post( + "/api/v1/query/explain", + response_model=QueryResponse, + status_code=status.HTTP_200_OK, + summary="Query with explanation", + description="Process query and include debug information" +) +@limiter.limit("50/minute") +async def process_query_with_explanation(request: Request, query_data: QueryRequest, user: dict = Depends(get_current_user)): + """ + Process query with debug information enabled (Explanation mode). + + This endpoint behaves like /query but forces debug_info to be included, + which provides insights into the retrieval and ranking process. + + Args: + request (Request): FastAPI Request object. + query_data (QueryRequest): Query parameters. + user (dict): Authenticated user context. + + Returns: + QueryResponse: Response with detailed debug information/explanation. + """ + query_data.include_debug_info = True + return await process_query(request, query_data) + + +@app.post( + "/api/v1/search", + response_model=ProductSearchResponse, + status_code=status.HTTP_200_OK, + summary="Product search", + description="Search products with natural language queries and filters" +) +@limiter.limit("100/minute") +async def search_products(request: Request, search_data: ProductSearchRequest, user: dict = Depends(get_current_user)): + """ + Search products with natural language queries and dynamic filtering. + + 1. Extracts filters from the natural language query. + 2. Analyzes query intent. + 3. Performs hybrid search on the product catalog. + 4. Optionally generates an LLM explanation of why results match. + + Args: + request (Request): FastAPI Request object (rate limiting). + search_data (ProductSearchRequest): Search query and manual filters. + user (dict): Authenticated user context. + + Returns: + ProductSearchResponse: Search results, extracted filters, and metadata. + + Raises: + HTTPException: If search fails or downstream services are unavailable. + """ + try: + start_time = time.time() + + logger.info(f"Processing product search: {search_data.query[:100]}") + + # Get catalog info to extract known categories + catalog_info = await orchestrator.get_catalog_info() + known_categories = catalog_info.get('categories', []) if catalog_info.get('loaded') else [] + + # Extract filters from query + extracted_filters = await filter_extractor.extract_async( + search_data.query, + known_categories=known_categories, + use_llm_fallback=True + ) + + # Merge with provided filters (provided filters take precedence) + applied_filters = extracted_filters.copy() + if search_data.filters: + applied_filters.update(search_data.filters) + + # Analyze query intent + query_analysis = query_analyzer.analyze(search_data.query, applied_filters) + + # Search products + search_results = await orchestrator.search_products( + query=query_analysis['semantic_query'], + filters=applied_filters, + limit=search_data.limit, + offset=search_data.offset + ) + + # Generate explanation if requested + explanation = None + if search_data.explain and search_results.get('results'): + # Call LLM for explanation + try: + llm_response = await orchestrator.generate_answer( + query=search_data.query, + context_chunks=search_results.get('results', [])[:5], # Use top 5 for context + model_type='simple' + ) + explanation = llm_response.get('answer', '') + except Exception as e: + logger.warning(f"Failed to generate explanation: {e}") + + # Generate suggested refinements + suggested_refinements = [] + if search_results.get('total_matches', 0) > search_data.limit: + if not applied_filters.get('categories') and known_categories: + suggested_refinements.append(f"Try filtering by {known_categories[0]}") + if not applied_filters.get('price_max'): + suggested_refinements.append("Narrow by price range") + + processing_time = (time.time() - start_time) * 1000 + + logger.info(f"Product search completed in {processing_time:.2f}ms") + + return ProductSearchResponse( + query_interpretation={ + "semantic_query": query_analysis['semantic_query'], + "extracted_filters": extracted_filters, + "intent": query_analysis['intent'] + }, + applied_filters=applied_filters, + total_matches=search_results.get('total_matches', len(search_results.get('results', []))), + results=search_results.get('results', []), + explanation=explanation, + suggested_refinements=suggested_refinements + ) + + except Exception as e: + logger.error(f"Error processing product search: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Product search failed: {str(e)}" + ) + + +@app.get( + "/health", + response_model=HealthResponse, + status_code=status.HTTP_200_OK, + summary="Health check" +) +async def health_check(): + """ + Gateway health check endpoint. + + Returns: + HealthResponse: Status of the gateway service itself. + """ + return HealthResponse( + status="healthy", + service="gateway", + deployment_phase=settings.deployment_phase + ) + + +@app.get( + "/api/v1/health/services", + response_model=ServiceHealthResponse, + status_code=status.HTTP_200_OK, + summary="Check all services health", + description="Check health of all backend services" +) +async def check_all_services(): + """ + Check health of all connected backend services. + + Queries the health endpoints of Embedding, Retrieval, LLM, and Ingestion services. + + Returns: + ServiceHealthResponse: Aggregate status of all services. + """ + embedding_health = await orchestrator.check_service_health(settings.embedding_service_url) + retrieval_health = await orchestrator.check_service_health(settings.retrieval_service_url) + llm_health = await orchestrator.check_service_health(settings.llm_service_url) + ingestion_health = await orchestrator.check_service_health(settings.ingestion_service_url) + + return ServiceHealthResponse( + gateway="healthy", + embedding=embedding_health, + retrieval=retrieval_health, + llm=llm_health, + ingestion=ingestion_health + ) + + +@app.get("/", summary="Root endpoint") +async def root(): + """ + Root endpoint with service information. + + Returns: + dict: Basic service info including version and status. + """ + return { + "service": "Gateway Service", + "version": "1.0.0", + "status": "running", + "docs": "/docs", + "health": "/health", + "deployment_phase": settings.deployment_phase + } + + +# Application startup +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting Gateway Service on {settings.gateway_host}:{settings.gateway_port}") + logger.info(f"Deployment phase: {settings.deployment_phase}") + logger.info(f"Embedding service: {settings.embedding_service_url}") + logger.info(f"Retrieval service: {settings.retrieval_service_url}") + logger.info(f"LLM service: {settings.llm_service_url}") + logger.info(f"Ingestion service: {settings.ingestion_service_url}") + + uvicorn.run( + app, + host=settings.gateway_host, # nosec B104 - Binding to all interfaces is intentional for Docker container + port=settings.gateway_port, + log_level=settings.log_level.lower() + ) + diff --git a/sample_solutions/HybridSearch/api/gateway/requirements.txt b/sample_solutions/HybridSearch/api/gateway/requirements.txt new file mode 100644 index 00000000..31c52448 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/requirements.txt @@ -0,0 +1,29 @@ +# Gateway Service Requirements +# API Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 + +# HTTP Client +httpx==0.25.1 +aiohttp>=3.11.7 + +# Logging and Monitoring +python-json-logger==2.0.7 +prometheus-client==0.19.0 + +# Retry Logic +tenacity==8.2.3 + +# Rate Limiting +slowapi==0.1.9 + +# Environment +python-dotenv==1.0.0 + +# Utilities +# Security +python-jose[cryptography]>=3.4.0 +python-multipart>=0.0.22 + diff --git a/sample_solutions/HybridSearch/api/gateway/services/__init__.py b/sample_solutions/HybridSearch/api/gateway/services/__init__.py new file mode 100644 index 00000000..9a148977 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/__init__.py @@ -0,0 +1,4 @@ +""" +Gateway Service Modules +""" + diff --git a/sample_solutions/HybridSearch/api/gateway/services/auth.py b/sample_solutions/HybridSearch/api/gateway/services/auth.py new file mode 100644 index 00000000..ff8041f8 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/auth.py @@ -0,0 +1,120 @@ +""" +Authentication Service +Handles JWT token verification from Keycloak +""" + +import logging +from typing import Optional, Dict, Any +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import jwt, JWTError +import httpx +from config import settings + +logger = logging.getLogger(__name__) + +# OAuth2 scheme for token extraction +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + +class AuthService: + """ + Enterprise authentication service using Keycloak JWTs. + + Handles validation of JWT tokens issued by Keycloak, including + signature verification (in production) and public key caching. + """ + + def __init__(self): + self.enabled = settings.deployment_phase == "production" + self.base_url = settings.base_url + self.realm = settings.keycloak_realm + + # In a real production environment, you should fetch the public key from Keycloak + # http://{keycloak}/realms/{realm}/protocol/openid-connect/certs + self.jwks_url = f"{self.base_url}/certs" if self.base_url else None + self._cached_keys = None + + async def _get_public_keys(self) -> Dict[str, Any]: + """ + Fetch and cache public keys from Keycloak. + + Returns: + Dict[str, Any]: The JSON Web Key Set (JWKS) from Keycloak. + """ + if self._cached_keys: + return self._cached_keys + + if not self.jwks_url: + return {} + + try: + async with httpx.AsyncClient(verify=settings.verify_ssl) as client: + response = await client.get(self.jwks_url) + response.raise_for_status() + self._cached_keys = response.json() + return self._cached_keys + except Exception as e: + logger.error(f"Failed to fetch public keys from Keycloak: {e}") + return {} + + async def verify_token(self, token: str) -> Optional[Dict[str, Any]]: + """ + Verify the JWT token. + + Args: + token (str): The raw JWT token string. + + Returns: + Optional[Dict[str, Any]]: Decoded token claims if valid. + + Raises: + HTTPException: If token is missing, invalid, or expired. + """ + # BYPASS AUTH FOR TESTING + return {"sub": "anonymous-user", "preferred_username": "dev-user", "roles": ["admin"]} + + if not self.enabled: + return {"sub": "anonymous", "preferred_username": "dev-user"} + + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication token missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + # In enterprise mode, verify against provided Keycloak + # Note: For simplicity, we are doing a loose check here. + # In a real-world scenario, you MUST verify signatures and audience. + + # 1. Decode without verification to get headers/claims + unverified_claims = jwt.get_unverified_claims(token) + + # 2. Add verification logic here (signatures, etc.) + # For this PoC, we rely on the Enterprise Gateway having validated the token already + # or we do a basic verification. + + return unverified_claims + + except JWTError as e: + logger.error(f"JWT verification failed: {e}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + +# Dependency for FastAPI routes +async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]: + """ + FastAPI dependency to get and verify the current user. + + Args: + token (str): Bearer token from the Authorization header. + + Returns: + Dict[str, Any]: The authenticated user's claims. + """ + auth_service = AuthService() + return await auth_service.verify_token(token) diff --git a/sample_solutions/HybridSearch/api/gateway/services/complexity_detector.py b/sample_solutions/HybridSearch/api/gateway/services/complexity_detector.py new file mode 100644 index 00000000..90c7743a --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/complexity_detector.py @@ -0,0 +1,91 @@ +""" +Query Complexity Detector +Determines if a query is simple or complex +""" + +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + + +class ComplexityDetector: + """ + Detect query complexity for routing. + + Classifies queries as 'simple' or 'complex' based on keywords, + heuristics (length, question count), and structural patterns. + """ + + # Query patterns + SIMPLE_INDICATORS = [ + "what is", "who is", "when did", "when was", "where is", "where did", + "define", "list", "name", "how many", "how much", + "show me", "tell me", "give me" + ] + + COMPLEX_INDICATORS = [ + "compare", "analyze", "analyse", "explain why", "relationship between", + "impact of", "evaluate", "synthesize", "synthesise", + "how does", "affect", "effect", + "differences between", "similarities", "similar to", + "trend", "pattern", "correlation", "cause", "consequence", + "pros and cons", "advantages", "disadvantages", + "summarize", "summarise", "overview" + ] + + def detect(self, query: str) -> Dict[str, str]: + """ + Detect query complexity. + + Args: + query (str): The user query string. + + Returns: + Dict[str, str]: A dictionary containing: + - 'complexity': 'simple' or 'complex' + - 'reasoning': Explanation of the classification + """ + query_lower = query.lower().strip() + + # Check for complex indicators first (higher priority) + for indicator in self.COMPLEX_INDICATORS: + if indicator in query_lower: + logger.debug(f"Complex indicator found: '{indicator}'") + return { + "complexity": "complex", + "reasoning": f"Contains complex indicator: '{indicator}'" + } + + # Check for simple indicators + for indicator in self.SIMPLE_INDICATORS: + if indicator in query_lower: + logger.debug(f"Simple indicator found: '{indicator}'") + return { + "complexity": "simple", + "reasoning": f"Contains simple indicator: '{indicator}'" + } + + # Heuristic rules + word_count = len(query.split()) + question_count = query.count("?") + + # Long queries or multiple questions suggest complexity + if word_count > 15: + return { + "complexity": "complex", + "reasoning": f"Long query ({word_count} words)" + } + + if question_count > 1: + return { + "complexity": "complex", + "reasoning": f"Multiple questions ({question_count})" + } + + # Default to simple for short, direct questions + return { + "complexity": "simple", + "reasoning": "Short direct question (default)" + } + diff --git a/sample_solutions/HybridSearch/api/gateway/services/filter_extractor.py b/sample_solutions/HybridSearch/api/gateway/services/filter_extractor.py new file mode 100644 index 00000000..38ffe288 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/filter_extractor.py @@ -0,0 +1,310 @@ +""" +Filter Extractor +Extracts structured filters from natural language queries +""" + +import logging +import re +import httpx +from typing import Dict, List, Optional, Any +from config import settings + +logger = logging.getLogger(__name__) + + +class FilterExtractor: + """Extract filters from natural language queries""" + + def __init__(self, llm_service_url: str = None): + """ + Initialize filter extractor. + + Args: + llm_service_url (str, optional): URL of LLM service for complex extraction fallback. + Defaults to configuration settings if not provided. + """ + self.llm_service_url = llm_service_url or getattr(settings, 'llm_service_url', 'http://localhost:8003') + + # Price filter patterns + self.price_patterns = [ + (r'under\s+\$?(\d+(?:\.\d+)?)', 'price_max'), + (r'less\s+than\s+\$?(\d+(?:\.\d+)?)', 'price_max'), + (r'below\s+\$?(\d+(?:\.\d+)?)', 'price_max'), + (r'over\s+\$?(\d+(?:\.\d+)?)', 'price_min'), + (r'more\s+than\s+\$?(\d+(?:\.\d+)?)', 'price_min'), + (r'above\s+\$?(\d+(?:\.\d+)?)', 'price_min'), + (r'\$?(\d+(?:\.\d+)?)\s+to\s+\$?(\d+(?:\.\d+)?)', 'price_range'), + (r'between\s+\$?(\d+(?:\.\d+)?)\s+and\s+\$?(\d+(?:\.\d+)?)', 'price_range'), + (r'\$?(\d+(?:\.\d+)?)\s*-\s*\$?(\d+(?:\.\d+)?)', 'price_range'), + (r'around\s+\$?(\d+(?:\.\d+)?)', 'price_around'), + (r'about\s+\$?(\d+(?:\.\d+)?)', 'price_around'), + ] + + # Rating filter patterns + self.rating_patterns = [ + (r'(\d+(?:\.\d+)?)\+?\s*stars?', 'rating_min'), + (r'(\d+(?:\.\d+)?)\s+star\s+and\s+above', 'rating_min'), + (r'rated\s+(\d+(?:\.\d+)?)', 'rating_min'), + (r'highly\s+rated', 'rating_high'), + (r'top\s+rated', 'rating_high'), + (r'best\s+reviewed', 'rating_high'), + (r'well\s+reviewed', 'rating_well'), + ] + + # Quantity/limit patterns + self.limit_patterns = [ + (r'top\s+(\d+)', 'limit'), + (r'best\s+(\d+)', 'limit'), + (r'show\s+me\s+(\d+)', 'limit'), + (r'first\s+(\d+)', 'limit'), + ] + + def extract_price_filters(self, query: str) -> Dict[str, float]: + """ + Extract price filters from query using regex patterns. + + Args: + query (str): User query string. + + Returns: + Dict[str, float]: Dictionary containing: + - 'price_min': Minimum price filter + - 'price_max': Maximum price filter + """ + filters = {} + query_lower = query.lower() + + for pattern, filter_type in self.price_patterns: + match = re.search(pattern, query_lower, re.IGNORECASE) + if match: + if filter_type == 'price_max': + filters['price_max'] = float(match.group(1)) + elif filter_type == 'price_min': + filters['price_min'] = float(match.group(1)) + elif filter_type == 'price_range': + filters['price_min'] = float(match.group(1)) + filters['price_max'] = float(match.group(2)) + elif filter_type == 'price_around': + price = float(match.group(1)) + filters['price_min'] = price * 0.8 + filters['price_max'] = price * 1.2 + break # Use first match + + return filters + + def extract_rating_filters(self, query: str) -> Dict[str, float]: + """ + Extract rating filters from query using regex patterns. + + Args: + query (str): User query string. + + Returns: + Dict[str, float]: Dictionary containing 'rating_min' if found. + """ + filters = {} + query_lower = query.lower() + + for pattern, filter_type in self.rating_patterns: + match = re.search(pattern, query_lower, re.IGNORECASE) + if match: + if filter_type == 'rating_min': + filters['rating_min'] = float(match.group(1)) + elif filter_type == 'rating_high': + filters['rating_min'] = 4.0 + elif filter_type == 'rating_well': + filters['rating_min'] = 3.5 + break # Use first match + + return filters + + def extract_limit(self, query: str, default: int = 10) -> int: + """ + Extract result limit from query (e.g., "top 5 items"). + + Args: + query (str): User query string. + default (int): Default limit if no pattern matches. + + Returns: + int: The extracted limit or the default value. + """ + query_lower = query.lower() + + for pattern, _ in self.limit_patterns: + match = re.search(pattern, query_lower, re.IGNORECASE) + if match: + return int(match.group(1)) + + return default + + def extract_category_filters( + self, + query: str, + known_categories: List[str] = None + ) -> List[str]: + """ + Extract category filters from query using fuzzy matching against known categories. + + Args: + query (str): User query string. + known_categories (List[str]): List of valid categories in the catalog. + + Returns: + List[str]: List of matched category names. + """ + if not known_categories: + return [] + + query_lower = query.lower() + matched_categories = [] + + # Simple keyword matching + category_keywords = { + 'electronics': ['electronics', 'electronic', 'tech', 'technology'], + 'home': ['home', 'household', 'house'], + 'kitchen': ['kitchen', 'cooking', 'cookware'], + 'clothing': ['clothing', 'clothes', 'apparel', 'fashion'], + 'books': ['books', 'book', 'reading'], + 'sports': ['sports', 'sport', 'fitness', 'exercise'], + 'toys': ['toys', 'toy', 'games', 'game'], + } + + for category in known_categories: + category_lower = category.lower() + + # Direct match + if category_lower in query_lower: + matched_categories.append(category) + continue + + # Keyword match + for keyword, variations in category_keywords.items(): + if keyword in category_lower: + for variation in variations: + if variation in query_lower: + matched_categories.append(category) + break + if category in matched_categories: + break + + return list(set(matched_categories)) # Remove duplicates + + async def extract_with_llm(self, query: str) -> Dict[str, Any]: + """ + Extract filters using LLM (fallback for complex queries). + + Args: + query (str): User query string. + + Returns: + Dict[str, Any]: Dictionary containing extracted filters. + """ + try: + prompt = f"""Extract shopping filters from this query. Return JSON with: +- semantic_query: cleaned query without filters +- filters: object with price_min, price_max, rating_min, categories (array) +- intent: "semantic_browse", "filtered_search", "hybrid", "specific_product", or "comparison" + +Query: "{query}" + +Return only valid JSON, no other text.""" + + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.post( + f"{self.llm_service_url}/api/v1/llm/extract-filters", + json={"query": query, "prompt": prompt} + ) + + if response.status_code == 200: + data = response.json() + return data.get('filters', {}) + else: + logger.warning(f"LLM filter extraction failed: {response.status_code}") + return {} + + except Exception as e: + logger.warning(f"LLM filter extraction error: {e}") + return {} + + def extract( + self, + query: str, + known_categories: List[str] = None, + use_llm_fallback: bool = False + ) -> Dict[str, Any]: + """ + Extract all filters from query (synchronously). + + Combines regex-based extraction for price, rating, category, and limit. + Note: LLM fallback is skipped in synchronous orchestration unless called via async wrapper. + + Args: + query (str): User query string. + known_categories (List[str]): List of known categories. + use_llm_fallback (bool): Whether to attempt LLM fallback (skipped in sync method). + + Returns: + Dict[str, Any]: Dictionary containing all extracted filters. + """ + filters = {} + + # Extract price filters + price_filters = self.extract_price_filters(query) + filters.update(price_filters) + + # Extract rating filters + rating_filters = self.extract_rating_filters(query) + filters.update(rating_filters) + + # Extract category filters + if known_categories: + category_filters = self.extract_category_filters(query, known_categories) + if category_filters: + filters['categories'] = category_filters + + # Extract limit + limit = self.extract_limit(query) + if limit != 10: # Only include if different from default + filters['limit'] = limit + + # If no filters found and use_llm_fallback, try LLM + if not filters and use_llm_fallback: + # Note: This would be async, so we'd need to handle it differently + # For now, we'll skip LLM fallback in sync context + logger.debug("No filters found via regex, but LLM fallback requires async") + + logger.info(f"Extracted filters: {filters}") + return filters + + async def extract_async( + self, + query: str, + known_categories: List[str] = None, + use_llm_fallback: bool = True + ) -> Dict[str, Any]: + """ + Extract filters asynchronously (supports LLM fallback). + + First attempts regex-based extraction. If no filters are found and + fallback is enabled, queries the LLM service. + + Args: + query (str): User query string. + known_categories (List[str]): List of known categories. + use_llm_fallback (bool): Whether to use LLM for complex queries. + + Returns: + Dict[str, Any]: Dictionary with extracted filters. + """ + # First try regex extraction + filters = self.extract(query, known_categories, use_llm_fallback=False) + + # If no filters found and LLM fallback enabled, try LLM + if not filters and use_llm_fallback: + llm_filters = await self.extract_with_llm(query) + if llm_filters: + filters.update(llm_filters) + + return filters + diff --git a/sample_solutions/HybridSearch/api/gateway/services/orchestrator.py b/sample_solutions/HybridSearch/api/gateway/services/orchestrator.py new file mode 100644 index 00000000..87ebbf9a --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/orchestrator.py @@ -0,0 +1,254 @@ +""" +Service Orchestrator +Coordinates calls to all backend services +""" + +import logging +import httpx +from typing import List, Dict, Any, Optional +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log +) + +logger = logging.getLogger(__name__) + + +class ServiceOrchestrator: + """ + Orchestrate calls to backend services. + + Manages communication with Retrieval, LLM, Embedding, and Ingestion services, + handling retries, timeouts, and error propagation. + """ + + def __init__( + self, + retrieval_service_url: str, + llm_service_url: str, + embedding_service_url: str = None, + ingestion_service_url: str = None + ): + """ + Initialize orchestrator. + + Args: + retrieval_service_url (str): URL of retrieval service. + llm_service_url (str): URL of LLM service. + embedding_service_url (str, optional): URL of embedding service. + ingestion_service_url (str, optional): URL of ingestion service. + """ + self.retrieval_service_url = retrieval_service_url + self.llm_service_url = llm_service_url + self.embedding_service_url = embedding_service_url + self.ingestion_service_url = ingestion_service_url + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=5), + retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException, httpx.ConnectError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True + ) + async def retrieve_context( + self, + query: str, + top_k: int = 10 + ) -> Dict[str, Any]: + """ + Retrieve relevant context for query with retry logic. + + Args: + query (str): The search query. + top_k (int): Number of results to retrieve. + + Returns: + Dict[str, Any]: Dictionary with retrieval results and metadata. + + Raises: + httpx.HTTPError: If retrieval service fails. + """ + try: + logger.info(f"Retrieving context for query: {query[:100]}") + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{self.retrieval_service_url}/api/v1/retrieve/hybrid", + json={ + "query": query, + "top_k_candidates": 100, + "top_k_fusion": 50, + "top_k_final": top_k + } + ) + response.raise_for_status() + return response.json() + + except (httpx.HTTPError, httpx.TimeoutException, httpx.ConnectError) as e: + logger.warning(f"Retrieval service error (will retry): {type(e).__name__}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error during retrieval: {e}") + raise Exception(f"Retrieval failed: {str(e)}") + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=5), + retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException, httpx.ConnectError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True + ) + async def generate_answer( + self, + query: str, + context_chunks: List[Dict], + model_type: str = "auto" + ) -> Dict[str, Any]: + """ + Generate answer using LLM with retry logic. + + Args: + query (str): The user query. + context_chunks (List[Dict]): Retrieved context chunks to use as grounding. + model_type (str): Model type strategy ('simple', 'complex', 'auto'). + + Returns: + Dict[str, Any]: Dictionary with generated answer, citations, and metadata. + + Raises: + httpx.HTTPError: If LLM service fails. + """ + try: + logger.info(f"Generating answer using {model_type} model") + + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + f"{self.llm_service_url}/api/v1/llm/generate", + json={ + "query": query, + "context_chunks": context_chunks, + "model_type": model_type, + "include_citations": True + } + ) + response.raise_for_status() + return response.json() + + except (httpx.HTTPError, httpx.TimeoutException, httpx.ConnectError) as e: + logger.warning(f"LLM service error (will retry): {type(e).__name__}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error during answer generation: {e}") + raise Exception(f"Answer generation failed: {str(e)}") + + async def check_service_health(self, service_url: str) -> Dict: + """ + Check health of a downstream service. + + Args: + service_url (str): Base URL of the service to check. + + Returns: + Dict: Health status dictionary covering status and details. + """ + try: + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.get(f"{service_url}/health") + response.raise_for_status() + return { + "status": "healthy", + "details": response.json() + } + except Exception as e: + logger.error(f"Health check failed for {service_url}: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + + async def search_products( + self, + query: str, + filters: Dict[str, Any], + limit: int = 20, + offset: int = 0 + ) -> Dict[str, Any]: + """ + Search products with filters. + + Coordinates embedding generation (if needed) and calls retrieval service. + + Args: + query (str): Search query. + filters (Dict[str, Any]): Extracted filters to apply. + limit (int): Max number of results. + offset (int): Pagination offset. + + Returns: + Dict[str, Any]: Product search results. + + Raises: + Exception: If embedding or retrieval fails. + """ + try: + logger.info(f"Searching products: query='{query}', filters={filters}") + + # Get query embedding + query_embedding = None + if self.embedding_service_url: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{self.embedding_service_url}/api/v1/embeddings/encode", + json={"texts": [query], "normalize": True} + ) + response.raise_for_status() + data = response.json() + # Extract first embedding from the list + embeddings = data.get("embeddings", []) + query_embedding = embeddings[0] if embeddings else None + + # Call retrieval service + async with httpx.AsyncClient(timeout=60.0) as client: + # retrieval service caps top_k at 100; guard here to avoid 422 + safe_top_k = min(limit * 5, 100) + response = await client.post( + f"{self.retrieval_service_url}/api/v1/search/products", + json={ + "query_embedding": query_embedding, + "query_text": query, + "filters": filters, + "top_k": safe_top_k + } + ) + response.raise_for_status() + return response.json() + + except Exception as e: + logger.error(f"Product search error: {e}") + raise Exception(f"Product search failed: {str(e)}") + + async def get_catalog_info(self) -> Dict[str, Any]: + """ + Get catalog information from ingestion service. + + Returns: + Dict[str, Any]: Dictionary with catalog metadata (categories, sizes, etc.). + """ + try: + if not self.ingestion_service_url: + return {"loaded": False, "error": "Ingestion service URL not configured"} + + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + f"{self.ingestion_service_url}/api/v1/products/catalog/info" + ) + response.raise_for_status() + return response.json() + + except Exception as e: + logger.error(f"Error getting catalog info: {e}") + return {"loaded": False, "error": str(e)} + diff --git a/sample_solutions/HybridSearch/api/gateway/services/query_analyzer.py b/sample_solutions/HybridSearch/api/gateway/services/query_analyzer.py new file mode 100644 index 00000000..61e74db7 --- /dev/null +++ b/sample_solutions/HybridSearch/api/gateway/services/query_analyzer.py @@ -0,0 +1,180 @@ +""" +Query Analyzer +Classifies query intent types for product search +""" + +import logging +import re +from typing import Dict, Optional +from enum import Enum + +logger = logging.getLogger(__name__) + + +class QueryIntent(str, Enum): + """Query intent types""" + SEMANTIC_BROWSE = "semantic_browse" # Pure semantic search + FILTERED_SEARCH = "filtered_search" # Explicit constraints + HYBRID_SEARCH = "hybrid" # Semantic + filters + SPECIFIC_PRODUCT = "specific_product" # Exact product lookup + COMPARISON = "comparison" # Product comparison + + +class QueryAnalyzer: + """ + Analyze queries to determine intent and extract information. + + Uses regex patterns and heuristic rules to classify user intent + (e.g., specific product search, comparison, browsing) and cleaning + the query for semantic search. + """ + + def __init__(self): + """ + Initialize query analyzer. + + Sets up regex patterns for specific product matching and comparison detection. + """ + # Patterns for specific product queries + self.specific_product_patterns = [ + r'show me (the )?([A-Z][a-zA-Z0-9\s-]+)', + r'find (the )?([A-Z][a-zA-Z0-9\s-]+)', + r'where is (the )?([A-Z][a-zA-Z0-9\s-]+)', + r'([A-Z][a-zA-Z0-9\s-]+) (model|version|product)' + ] + + # Patterns for comparison queries + self.comparison_patterns = [ + r'compare', + r'difference between', + r'vs\.?', + r'versus', + r'which (is|are) better', + r'which (one|product) should', + r'help me choose', + r'help me decide' + ] + + def classify_intent(self, query: str, has_filters: bool = False) -> QueryIntent: + """ + Classify query intent based on patterns and context. + + Args: + query (str): User query string. + has_filters (bool): Whether filters have already been extracted. + + Returns: + QueryIntent: The classified intent enum value. + """ + query_lower = query.lower() + + # Check for specific product queries + for pattern in self.specific_product_patterns: + if re.search(pattern, query, re.IGNORECASE): + return QueryIntent.SPECIFIC_PRODUCT + + # Check for comparison queries + for pattern in self.comparison_patterns: + if re.search(pattern, query_lower): + return QueryIntent.COMPARISON + + # Determine based on filters + if has_filters: + # Check if query has semantic content beyond filters + # Simple heuristic: if query is mostly filter keywords, it's filtered_search + filter_keywords = ['under', 'over', 'above', 'below', 'between', 'stars', 'rated', 'category'] + query_words = set(query_lower.split()) + filter_word_count = sum(1 for word in filter_keywords if word in query_words) + + if filter_word_count > 2 or len(query_words) < 5: + return QueryIntent.FILTERED_SEARCH + else: + return QueryIntent.HYBRID_SEARCH + else: + return QueryIntent.SEMANTIC_BROWSE + + def extract_semantic_query(self, query: str, filters: Dict) -> str: + """ + Extract cleaned semantic query (with filters significant phrases removed). + + Removes parts of the query that correspond to extracted filters + (like "under $50" or "red category") to improve semantic search quality. + + Args: + query (str): Original query. + filters (Dict): Extracted filters dictionary. + + Returns: + str: Cleaned semantic query string. + """ + # Remove filter-related phrases + semantic_query = query + + # Remove price-related phrases + price_patterns = [ + r'under \$?\d+', + r'less than \$?\d+', + r'below \$?\d+', + r'over \$?\d+', + r'more than \$?\d+', + r'above \$?\d+', + r'\$?\d+\s*to\s*\$?\d+', + r'between \$?\d+\s+and\s+\$?\d+', + r'around \$?\d+', + r'about \$?\d+' + ] + + for pattern in price_patterns: + semantic_query = re.sub(pattern, '', semantic_query, flags=re.IGNORECASE) + + # Remove rating-related phrases + rating_patterns = [ + r'\d+\+?\s*stars?', + r'\d+\s*star\s+and\s+above', + r'rated\s+\d+', + r'highly\s+rated', + r'top\s+rated', + r'best\s+reviewed', + r'well\s+reviewed' + ] + + for pattern in rating_patterns: + semantic_query = re.sub(pattern, '', semantic_query, flags=re.IGNORECASE) + + # Remove category mentions if they're in filters + if filters.get('categories'): + for category in filters['categories']: + semantic_query = re.sub(category, '', semantic_query, flags=re.IGNORECASE) + + # Clean up whitespace + semantic_query = ' '.join(semantic_query.split()) + + return semantic_query.strip() if semantic_query.strip() else query + + def analyze(self, query: str, filters: Dict) -> Dict: + """ + Analyze query and return intent classification and cleaned query. + + Args: + query (str): User query. + filters (Dict): Extracted filters. + + Returns: + Dict: Dictionary containing: + - 'intent': Classified intent string + - 'semantic_query': Query processing for semantic search + - 'original_query': The raw input + """ + intent = self.classify_intent(query, has_filters=bool(filters)) + semantic_query = self.extract_semantic_query(query, filters) + + result = { + "intent": intent.value, + "semantic_query": semantic_query, + "original_query": query + } + + logger.info(f"Query analyzed: intent={intent.value}, semantic_query='{semantic_query}'") + + return result + diff --git a/sample_solutions/HybridSearch/api/ingestion/Dockerfile b/sample_solutions/HybridSearch/api/ingestion/Dockerfile new file mode 100644 index 00000000..cbe9a4af --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/Dockerfile @@ -0,0 +1,35 @@ +# Document Ingestion Service Dockerfile +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code and create non-root user +COPY . . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app && \ + mkdir -p /data/indexes /data/documents /data/db && \ + chown -R appuser:appuser /data +USER appuser + +# Expose port +EXPOSE 8004 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8004/health || exit 1 + +# Run the application +CMD ["python", "main.py"] + diff --git a/sample_solutions/HybridSearch/api/ingestion/config.py b/sample_solutions/HybridSearch/api/ingestion/config.py new file mode 100644 index 00000000..03ce9905 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/config.py @@ -0,0 +1,76 @@ +""" +Document Ingestion Service Configuration +Manages environment variables and service settings +""" + +from pydantic_settings import BaseSettings +from pathlib import Path + +# Compute project root path (hybrid-search/) +# config.py is at: hybrid-search/api/ingestion/config.py +# So we need to go up 3 levels: ingestion -> api -> hybrid-search +_PROJECT_ROOT = Path(__file__).parent.parent.parent +_DEFAULT_DOCUMENT_PATH = str(_PROJECT_ROOT / "data" / "documents") +_DEFAULT_INDEX_PATH = str(_PROJECT_ROOT / "data" / "indexes") +_DEFAULT_DB_PATH = str(_PROJECT_ROOT / "data" / "metadata.db") + + +class Settings(BaseSettings): + """ + Service configuration with environment variable loading. + + This class defines configuration for the Ingestion Service, including: + - Service host and port + - Downstream service URLs + - Document storage paths + - File processing settings (chunk size, overlap, formats) + - Product catalog settings + """ + + # Deployment Phase + deployment_phase: str = "development" + + # Service Configuration + ingestion_port: int = 8004 + ingestion_host: str = "0.0.0.0" # nosec B104 - Binding to all interfaces is intentional for Docker container + + # Embedding Service + embedding_service_url: str = "http://localhost:8001" + + # Storage Paths (default to local development paths) + document_storage_path: str = _DEFAULT_DOCUMENT_PATH + index_storage_path: str = _DEFAULT_INDEX_PATH + metadata_db_path: str = _DEFAULT_DB_PATH + + # Document Processing + chunk_size: int = 256 # tokens (reduced to safe limit for 512-token models) + chunk_overlap: int = 25 # tokens + max_file_size_mb: int = 100 + supported_formats: str = "pdf,docx,xlsx,ppt,txt" + embedding_dim: int = 768 # BAAI/bge-base-en-v1.5 dimensions + embedding_batch_size: int = 32 # must match batch size from embedding service + + # Product Catalog Settings + system_mode: str = "document" # "document" or "product" + embedding_field_template: str = "{name}. {description}. Category: {category}. Brand: {brand}" + default_result_limit: int = 20 + max_products_per_catalog: int = 50000 + + # Logging + log_level: str = "INFO" + + class Config: + # Look for .env file in the hybrid-search root directory + env_file = Path(__file__).parent.parent.parent / ".env" + case_sensitive = False + extra = "ignore" # Ignore extra fields in .env file + + @property + def supported_formats_list(self) -> list: + """Get list of supported formats""" + return [fmt.strip() for fmt in self.supported_formats.split(",")] + + +# Global settings instance +settings = Settings() + diff --git a/sample_solutions/HybridSearch/api/ingestion/data/documents/.gitkeep b/sample_solutions/HybridSearch/api/ingestion/data/documents/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/sample_solutions/HybridSearch/api/ingestion/data/indexes/.gitkeep b/sample_solutions/HybridSearch/api/ingestion/data/indexes/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/sample_solutions/HybridSearch/api/ingestion/main.py b/sample_solutions/HybridSearch/api/ingestion/main.py new file mode 100644 index 00000000..e6292b9e --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/main.py @@ -0,0 +1,1148 @@ +""" +Document Ingestion Service +Handles document upload, processing, chunking, and indexing +""" + +import logging +import time +import uuid +import httpx +import asyncio +from pathlib import Path +from typing import List, Optional, Dict +from fastapi import FastAPI, File, UploadFile, HTTPException, status, Form +from contextlib import asynccontextmanager +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from config import settings +from services.document_parser import DocumentParser +from services.chunker import TextChunker +from services.index_manager import IndexManager +from services.metadata_store import MetadataStore +from services.product_parser import ProductParser +from services.product_processor import ProductProcessor +from schemas.product_schemas import ( + UploadResponse, ProcessingStatus, FieldMapping, + ProductCreate, CatalogMetadata +) + +# Configure logging +logging.basicConfig( + level=settings.log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + yield + # Shutdown + metadata_store.close() + logger.info("Service shutdown complete") + +app = FastAPI( + title="Document Ingestion Service", + description="Document processing, chunking, and indexing service", + version="1.0.0", + lifespan=lifespan +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize services +document_parser = DocumentParser() +text_chunker = TextChunker( + chunk_size=settings.chunk_size, + chunk_overlap=settings.chunk_overlap +) +index_manager = IndexManager( + index_storage_path=settings.index_storage_path, + embedding_dim=settings.embedding_dim +) +metadata_store = MetadataStore(settings.metadata_db_path) +product_parser = ProductParser() +product_processor = ProductProcessor( + embedding_field_template=getattr(settings, 'embedding_field_template', None) +) + +# Ensure storage directories exist +Path(settings.document_storage_path).mkdir(parents=True, exist_ok=True) + +# Job tracking for product processing +processing_jobs: Dict[str, Dict] = {} +system_mode: str = getattr(settings, 'system_mode', 'document') # 'document' or 'product' + + +# Response Models +class DocumentResponse(BaseModel): + """ + Response model for document upload. + + Attributes: + document_id (str): Unique identifier for the uploaded document. + filename (str): Name of the uploaded file. + file_type (str): Extension/type of the file. + status (str): Current processing status (e.g., 'processing'). + upload_timestamp (str): ISO timestamp of upload. + estimated_completion_time (Optional[str]): ETA for processing completion. + """ + document_id: str + filename: str + file_type: str + status: str + upload_timestamp: str + estimated_completion_time: Optional[str] = None + + +class DocumentStatus(BaseModel): + """ + Document processing status. + + Attributes: + document_id (str): Unique identifier. + filename (str): Name of the file. + file_type (str): Extension/type of the file. + processing_status (str): Current status ('processing', 'completed', 'failed'). + chunk_count (int): Number of chunks generated so far. + upload_timestamp (str): ISO timestamp of upload. + error_message (Optional[str]): detailed error if processing failed. + """ + document_id: str + filename: str + file_type: str + processing_status: str + chunk_count: int + upload_timestamp: str + error_message: Optional[str] = None + + +class IndexStats(BaseModel): + """ + Index statistics. + + Attributes: + total_documents (int): Total number of documents tracked. + total_chunks (int): Total number of chunks indexed. + faiss_vectors (int): Number of vectors in the FAISS index. + status_counts (dict): Breakdown of documents by status (completed, failed, etc.). + """ + total_documents: int + total_chunks: int + faiss_vectors: int + status_counts: dict + + +class HealthResponse(BaseModel): + """ + Health check response. + + Attributes: + status (str): Overall service status. + service (str): Service name. + deployment_phase (str): Deployment environment. + storage_paths (dict): Configuration of storage directories. + """ + status: str + service: str + deployment_phase: str + storage_paths: dict + + +# Helper Functions +async def get_embeddings(texts: List[str]) -> List[List[float]]: + """ + Call embedding service to get embeddings (with batching support). + + Args: + texts (List[str]): List of texts to embed. + + Returns: + List[List[float]]: List of embedding vectors. + + Raises: + httpx.HTTPError: If embedding service fails. + """ + BATCH_SIZE = settings.embedding_batch_size + logger.info(f"Embedding batch size: {BATCH_SIZE}") + all_embeddings = [] + + async with httpx.AsyncClient(timeout=120.0) as client: + # Process in batches if needed + for i in range(0, len(texts), BATCH_SIZE): + batch = texts[i:i + BATCH_SIZE] + logger.info(f"Getting embeddings for batch {i//BATCH_SIZE + 1}/{(len(texts)-1)//BATCH_SIZE + 1} ({len(batch)} texts)") + + response = await client.post( + f"{settings.embedding_service_url}/api/v1/embeddings/encode-batch", + json={"texts": batch, "normalize": True} + ) + response.raise_for_status() + data = response.json() + all_embeddings.extend(data["embeddings"]) + + return all_embeddings + + +async def process_document_async( + document_id: str, + file_path: Path, + file_type: str +): + """ + Process document asynchronously. + + Orchestrates parsing, chunking, embedding generation, and indexing. + Updates status in metadata store throughout the process. + + Args: + document_id (str): Unique document identifier. + file_path (Path): Path to the uploaded file on disk. + file_type (str): File extension/type (e.g., 'pdf'). + """ + try: + # Update status to processing + metadata_store.update_status(document_id, "processing") + + # Parse document + logger.info(f"Parsing document {document_id} ({file_type})") + pages_or_sections = document_parser.parse_document(file_path, file_type) + + if not pages_or_sections: + raise ValueError("No text content extracted from document") + + # Chunk text + logger.info(f"Chunking document {document_id}") + chunks = text_chunker.chunk_document(pages_or_sections, document_id) + + if not chunks: + raise ValueError("No chunks created from document") + + # Get embeddings + logger.info(f"Getting embeddings for {len(chunks)} chunks") + texts = [chunk["text"] for chunk in chunks] + embeddings = await get_embeddings(texts) + + # Add to indexes + logger.info(f"Adding {len(chunks)} chunks to indexes") + index_manager.add_chunks(chunks, embeddings) + + # Update status to completed + metadata_store.update_status( + document_id, + "completed", + chunk_count=len(chunks) + ) + + logger.info(f"Document {document_id} processed successfully ({len(chunks)} chunks)") + + except Exception as e: + logger.error(f"Error processing document {document_id}: {e}", exc_info=True) + metadata_store.update_status( + document_id, + "failed", + error_message=str(e) + ) + + +async def process_products_async( + job_id: str, + products: List[Dict], + field_mapping: FieldMapping, + catalog_name: str +): + """ + Process products asynchronously. + + Handles field mapping, validation, embedding generation, and indexing of product data. + Updates job status and progress. + + Args: + job_id (str): Unique job identifier. + products (List[Dict]): List of raw product dictionaries. + field_mapping (FieldMapping): Mapping configuration for product fields. + catalog_name (str): Name of the catalog. + """ + try: + # Update job status + processing_jobs[job_id]['status'] = 'processing' + processing_jobs[job_id]['current_step'] = 'Processing products...' + processing_jobs[job_id]['products_total'] = len(products) + + # Apply field mapping + logger.info(f"Applying field mapping for job {job_id}") + mapped_products = product_parser.apply_field_mapping(products, field_mapping) + + # Process and validate products + logger.info(f"Processing {len(mapped_products)} products for job {job_id}") + processed_products, invalid_products = product_processor.process_batch( + mapped_products, + batch_size=100, + skip_invalid=True + ) + + if invalid_products: + processing_jobs[job_id]['errors'].extend([ + f"Product {i+1}: {', '.join(err['errors'])}" + for i, err in enumerate(invalid_products) + ]) + + if not processed_products: + raise ValueError("No valid products found after processing") + + # Clear existing products (single catalog mode) + logger.info(f"Clearing existing products for job {job_id}") + metadata_store.clear_all_products() + index_manager.clear_products_only() # Only clear products, keep documents intact + + # Process in batches + BATCH_SIZE = 100 + total_batches = (len(processed_products) + BATCH_SIZE - 1) // BATCH_SIZE + + categories = set() + prices = [] + + for batch_idx in range(0, len(processed_products), BATCH_SIZE): + batch = processed_products[batch_idx:batch_idx + BATCH_SIZE] + batch_num = batch_idx // BATCH_SIZE + 1 + + # Update progress + processing_jobs[job_id]['products_processed'] = min(batch_idx + BATCH_SIZE, len(processed_products)) + processing_jobs[job_id]['progress'] = processing_jobs[job_id]['products_processed'] / len(processed_products) + processing_jobs[job_id]['current_step'] = f'Processing batch {batch_num}/{total_batches}...' + + # Collect categories and prices + for product in batch: + if product.get('category'): + categories.add(product['category']) + if product.get('price') is not None: + prices.append(product['price']) + + # Get embeddings for batch + logger.info(f"Getting embeddings for batch {batch_num}/{total_batches} ({len(batch)} products)") + embedding_texts = [p['embedding_text'] for p in batch] + embeddings = await get_embeddings(embedding_texts) + + # Add to database + for product, embedding in zip(batch, embeddings): + metadata_store.add_product( + product_id=product['id'], + name=product['name'], + description=product['description'], + category=product['category'], + price=product['price'], + rating=product['rating'], + review_count=product['review_count'], + image_url=product['image_url'], + brand=product['brand'], + embedding_text=product['embedding_text'] + ) + + # Add to product indexes + product_entries = [ + { + 'id': product['id'], + 'name': product['name'], + 'description': product['description'], + 'category': product['category'], + 'price': product['price'], + 'rating': product['rating'], + 'review_count': product['review_count'], + 'brand': product['brand'], + 'image_url': product['image_url'], + 'embedding_text': product['embedding_text'] + } + for product in batch + ] + index_manager.add_products(product_entries, embeddings) + + logger.info(f"Processed batch {batch_num}/{total_batches}") + + # Update catalog metadata + price_min = min(prices) if prices else None + price_max = max(prices) if prices else None + + metadata_store.update_catalog_metadata( + catalog_name=catalog_name, + product_count=len(processed_products), + categories=list(categories), + price_range_min=price_min, + price_range_max=price_max + ) + + # Update job status + processing_jobs[job_id]['status'] = 'complete' + processing_jobs[job_id]['progress'] = 1.0 + processing_jobs[job_id]['current_step'] = 'Complete' + + logger.info(f"Job {job_id} completed: {len(processed_products)} products processed") + + except Exception as e: + logger.error(f"Error processing products for job {job_id}: {e}", exc_info=True) + processing_jobs[job_id]['status'] = 'error' + processing_jobs[job_id]['errors'].append(str(e)) + processing_jobs[job_id]['current_step'] = f'Error: {str(e)}' + + +# API Endpoints +@app.post( + "/api/v1/documents/upload", + response_model=DocumentResponse, + status_code=status.HTTP_202_ACCEPTED, + summary="Upload document", + description="Upload a document for processing and indexing" +) +async def upload_document( + file: UploadFile = File(..., description="Document file to upload") +): + """ + Upload and process a document. + + Accepts a file upload, validates format and size, saves it locally, + and triggers asynchronous processing. + + Args: + file (UploadFile): The uploaded document file. + + Returns: + DocumentResponse: Processing task details including document ID. + + Raises: + HTTPException: For invalid files or I/O errors. + """ + try: + # Validate file type + if not file.filename: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Filename is required" + ) + + file_ext = Path(file.filename).suffix.lstrip('.').lower() + if file_ext not in settings.supported_formats_list: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported file type: {file_ext}. " + f"Supported: {', '.join(settings.supported_formats_list)}" + ) + + # Read file content + try: + content = await file.read() + file_size = len(content) + logger.info(f"Read file: {file.filename}, size: {file_size} bytes") + except Exception as e: + logger.error(f"Error reading file: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to read file: {str(e)}" + ) + + # Validate file size + max_size = settings.max_file_size_mb * 1024 * 1024 + if file_size > max_size: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail=f"File size ({file_size / 1024 / 1024:.2f}MB) exceeds maximum " + f"({settings.max_file_size_mb}MB)" + ) + + if file_size == 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="File is empty" + ) + + # Generate document ID + document_id = f"doc_{uuid.uuid4().hex[:12]}" + + # Save file + try: + storage_path = Path(settings.document_storage_path) / document_id + storage_path.mkdir(parents=True, exist_ok=True) + + # Sanitize filename to avoid path issues + safe_filename = file.filename.replace('/', '_').replace('\\', '_') + file_path = storage_path / safe_filename + + with open(file_path, 'wb') as f: + f.write(content) + + logger.info(f"Saved file: {file_path} ({file_size} bytes)") + except PermissionError as e: + logger.error(f"Permission error saving file: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Permission denied saving file to {settings.document_storage_path}" + ) + except OSError as e: + logger.error(f"OS error saving file: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to save file: {str(e)}" + ) + + # Add to metadata store + try: + metadata_store.add_document( + document_id=document_id, + filename=file.filename, + file_type=file_ext, + file_size=file_size, + metadata={"original_filename": file.filename} + ) + except Exception as e: + logger.error(f"Error adding document to metadata store: {e}", exc_info=True) + # Try to clean up saved file + try: + if file_path.exists(): + file_path.unlink() + if storage_path.exists(): + storage_path.rmdir() + except: + pass + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to add document to metadata store: {str(e)}" + ) + + # Process document asynchronously + try: + import asyncio + asyncio.create_task(process_document_async(document_id, file_path, file_ext)) + except Exception as e: + logger.error(f"Error creating async task: {e}", exc_info=True) + # Don't fail the upload if async task creation fails + # The document is saved and can be processed later + + # Get document info + try: + doc = metadata_store.get_document(document_id) + if not doc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Document saved but not found in metadata store" + ) + except Exception as e: + logger.error(f"Error retrieving document info: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve document info: {str(e)}" + ) + + return DocumentResponse( + document_id=document_id, + filename=file.filename, + file_type=file_ext, + status="processing", + upload_timestamp=doc['upload_timestamp'] + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error uploading document: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to upload document: {str(e)}" + ) + + +@app.get( + "/api/v1/documents/{document_id}/status", + response_model=DocumentStatus, + status_code=status.HTTP_200_OK, + summary="Get document status", + description="Check the processing status of a document" +) +async def get_document_status(document_id: str): + """ + Get document processing status. + + Args: + document_id (str): Document identifier. + + Returns: + DocumentStatus: Current status and progress stats. + + Raises: + HTTPException: If document is not found. + """ + doc = metadata_store.get_document(document_id) + + if not doc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Document not found: {document_id}" + ) + + return DocumentStatus(**doc) + + +@app.delete( + "/api/v1/documents/clear-all", + status_code=status.HTTP_200_OK, + summary="Clear all indexes", + description="Clear all vector indexes (FAISS and BM25) and metadata. WARNING: This will delete all indexed documents!" +) +async def clear_all_indexes(): + """ + Clear all vector indexes and metadata. + + WARNING: This operation cannot be undone. All indexed documents + and associated metadata will be permanently removed. + + Returns: + dict: Success message. + """ + try: + # Clear indexes + index_manager.clear_all() + + # Clear metadata store + metadata_store.clear_all() + + logger.warning("All indexes and metadata cleared by user request") + + return { + "message": "All indexes and metadata cleared successfully", + "status": "success" + } + + except Exception as e: + logger.error(f"Error clearing indexes: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to clear indexes: {str(e)}" + ) + + +@app.delete( + "/api/v1/documents/{document_id}", + status_code=status.HTTP_200_OK, + summary="Delete document", + description="Delete a document and its associated data" +) +async def delete_document(document_id: str): + """ + Delete document and all associated data. + + Removes the document from vector indexes, metadata store, and file system. + + Args: + document_id (str): Document identifier. + + Returns: + dict: Success message. + + Raises: + HTTPException: If document is not found or deletion fails. + """ + doc = metadata_store.get_document(document_id) + + if not doc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Document not found: {document_id}" + ) + + try: + # Delete from indexes + index_manager.delete_document(document_id) + + # Delete file + storage_path = Path(settings.document_storage_path) / document_id + if storage_path.exists(): + import shutil + shutil.rmtree(storage_path) + + # Delete from metadata store + metadata_store.delete_document(document_id) + + logger.info(f"Deleted document: {document_id}") + + return {"message": f"Document {document_id} deleted successfully"} + + except Exception as e: + logger.error(f"Error deleting document {document_id}: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete document: {str(e)}" + ) + + +@app.get( + "/api/v1/documents/stats", + response_model=IndexStats, + status_code=status.HTTP_200_OK, + summary="Get index statistics", + description="Get statistics about indexed documents and chunks" +) +async def get_stats(): + """ + Get index statistics. + + Returns: + IndexStats: overview of total documents, chunks, and vector counts. + """ + db_stats = metadata_store.get_stats() + index_stats = index_manager.get_stats() + + return IndexStats( + total_documents=db_stats["total_documents"], + total_chunks=index_stats["total_chunks"], + faiss_vectors=index_stats["faiss_vectors"], + status_counts=db_stats["status_counts"] + ) + + +@app.get( + "/health", + response_model=HealthResponse, + status_code=status.HTTP_200_OK, + summary="Health check" +) +async def health_check(): + """ + Health check endpoint. + + Returns: + HealthResponse: Status of the service and storage path configurations. + """ + return HealthResponse( + status="healthy", + service="ingestion", + deployment_phase=settings.deployment_phase, + storage_paths={ + "documents": settings.document_storage_path, + "indexes": settings.index_storage_path, + "metadata": settings.metadata_db_path + } + ) + + +# Product Endpoints +@app.post( + "/api/v1/products/upload", + response_model=UploadResponse, + status_code=status.HTTP_202_ACCEPTED, + summary="Upload product catalog", + description="Upload CSV/JSON/XLSX file with products. Returns job_id for async processing." +) +async def upload_products( + file: UploadFile = File(..., description="Product catalog file (CSV/JSON/XLSX)") +): + """ + Upload product catalog file (CSV/JSON/XLSX). + + Parses the file to detect columns and generates a suggested field mapping. + + Args: + file (UploadFile): The product catalog file. + + Returns: + UploadResponse: Job ID and mapping suggestions. + + Raises: + HTTPException: For unsupported formats or parsing errors. + """ + try: + # Validate file type + if not file.filename: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Filename is required" + ) + + file_ext = Path(file.filename).suffix.lstrip('.').lower() + if file_ext not in ['csv', 'json', 'xlsx', 'xls']: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported file type: {file_ext}. Supported: csv, json, xlsx" + ) + + # Read file content + content = await file.read() + if len(content) == 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="File is empty" + ) + + # Parse file + try: + products, detected_columns, suggested_mapping = product_parser.parse_file(content, file.filename) + except Exception as e: + logger.error(f"Error parsing file: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Failed to parse file: {str(e)}" + ) + + if not products: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No products found in file" + ) + + # Ensure detected_columns are strings (not tuples) + detected_columns = [str(col) if not isinstance(col, str) else col for col in detected_columns] + + # Generate job ID + job_id = str(uuid.uuid4()) + + # Store job info + processing_jobs[job_id] = { + 'status': 'pending_confirmation', + 'filename': file.filename, + 'products': products, + 'detected_columns': detected_columns, + 'suggested_mapping': suggested_mapping, + 'products_processed': 0, + 'products_total': len(products), + 'progress': 0.0, + 'current_step': 'Waiting for field mapping confirmation', + 'errors': [] + } + + requires_confirmation = not suggested_mapping.name + + # If confirmation is not required, start processing immediately + if not requires_confirmation: + job = processing_jobs[job_id] + job['status'] = 'processing' + job['field_mapping'] = suggested_mapping + job['catalog_name'] = Path(file.filename).stem + + # Apply field mapping + mapped_products = product_parser.apply_field_mapping( + products, + suggested_mapping + ) + + # Start async processing + asyncio.create_task(process_products_async( + job_id, + mapped_products, + suggested_mapping, + job['catalog_name'] + )) + + logger.info(f"Auto-started processing for job {job_id}") + + return UploadResponse( + job_id=job_id, + status='pending_confirmation' if requires_confirmation else 'processing', + detected_columns=detected_columns, + suggested_mapping=suggested_mapping if suggested_mapping.name else None, + requires_confirmation=requires_confirmation + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error uploading products: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to upload products: {str(e)}" + ) + + +@app.post( + "/api/v1/products/confirm", + status_code=status.HTTP_202_ACCEPTED, + summary="Confirm field mapping and start processing", + description="Confirm field mapping and start product processing" +) +async def confirm_product_mapping( + job_id: str = Form(...), + catalog_name: str = Form(None), + field_mapping: str = Form(...) # JSON string of FieldMapping +): + """ + Confirm field mapping and start processing. + + Validates field mapping, updates job status, and triggers async processing. + + Args: + job_id (str): Job identifier. + catalog_name (str): Name of the catalog (optional). + field_mapping (str): JSON string representation of FieldMapping object. + + Returns: + dict: Success message and job status. + + Raises: + HTTPException: If job not found or JSON is invalid. + """ + try: + if job_id not in processing_jobs: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Job not found: {job_id}" + ) + + job = processing_jobs[job_id] + if job['status'] != 'pending_confirmation': + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Job {job_id} is not in pending_confirmation status" + ) + + # Parse field mapping + import json + mapping_dict = json.loads(field_mapping) + field_mapping_obj = FieldMapping(**mapping_dict) + + if not field_mapping_obj.name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Field mapping must include 'name' field" + ) + + # Update job + job['field_mapping'] = field_mapping_obj + job['catalog_name'] = catalog_name or Path(job['filename']).stem + + # Apply field mapping + mapped_products = product_parser.apply_field_mapping( + job['products'], + field_mapping_obj + ) + + # Start async processing + asyncio.create_task(process_products_async( + job_id, + mapped_products, + field_mapping_obj, + job['catalog_name'] + )) + + return { + "message": "Processing started", + "job_id": job_id, + "status": "processing" + } + + except HTTPException: + raise + except json.JSONDecodeError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid JSON in field_mapping: {str(e)}" + ) + except Exception as e: + logger.error(f"Error confirming mapping: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to confirm mapping: {str(e)}" + ) + + +@app.get( + "/api/v1/products/status/{job_id}", + response_model=ProcessingStatus, + status_code=status.HTTP_200_OK, + summary="Get processing status", + description="Check the processing status of a product upload job" +) +async def get_product_status(job_id: str): + """ + Get product processing status. + + Args: + job_id (str): Job identifier. + + Returns: + ProcessingStatus: Current job status and progress. + + Raises: + HTTPException: If job is not found. + """ + if job_id not in processing_jobs: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Job not found: {job_id}" + ) + + job = processing_jobs[job_id] + + return ProcessingStatus( + job_id=job_id, + status=job['status'], + progress=job.get('progress', 0.0), + products_processed=job.get('products_processed', 0), + products_total=job.get('products_total', 0), + current_step=job.get('current_step', 'Unknown'), + errors=job.get('errors', []) + ) + + +@app.delete( + "/api/v1/products/clear", + status_code=status.HTTP_200_OK, + summary="Clear product catalog", + description="Clear all products from the catalog. WARNING: This will delete all indexed products!" +) +async def clear_products(): + """ + Clear all products from the catalog. + + WARNING: Removes all product data from metadata store and indexes. + + Returns: + dict: Success message. + """ + try: + metadata_store.clear_all_products() + index_manager.clear_all() + + logger.warning("All products cleared by user request") + + return { + "message": "All products cleared successfully", + "status": "success" + } + + except Exception as e: + logger.error(f"Error clearing products: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to clear products: {str(e)}" + ) + + +@app.get( + "/api/v1/products/mode", + status_code=status.HTTP_200_OK, + summary="Get system mode", + description="Get current system mode (document or product)" +) +async def get_mode(): + """ + Get current system mode. + + Returns: + dict: Current mode ('document' or 'product') and available modes. + """ + return { + "mode": system_mode, + "available_modes": ["document", "product"] + } + + +@app.post( + "/api/v1/products/mode", + status_code=status.HTTP_200_OK, + summary="Set system mode", + description="Switch between document and product mode" +) +async def set_mode(mode: str = Form(...)): + """ + Set system mode. + + Switches the system between 'document' and 'product' modes, affecting + default search behavior. + + Args: + mode (str): System mode ('document' or 'product'). + + Returns: + dict: Confirmation message. + """ + global system_mode + + if mode not in ['document', 'product']: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mode: {mode}. Must be 'document' or 'product'" + ) + + system_mode = mode + logger.info(f"System mode changed to: {mode}") + + return { + "message": f"System mode set to {mode}", + "mode": system_mode + } + + +@app.get( + "/api/v1/products/catalog/info", + status_code=status.HTTP_200_OK, + summary="Get catalog info", + description="Get information about the current product catalog" +) +async def get_catalog_info(): + """ + Get catalog information. + + Returns: + dict: Catalog metadata, product counts, price ranges, etc. + """ + catalog_metadata = metadata_store.get_catalog_metadata() + product_stats = metadata_store.get_product_stats() + + if not catalog_metadata: + return { + "loaded": False, + "message": "No catalog loaded" + } + + return { + "loaded": True, + "name": catalog_metadata['catalog_name'], + "product_count": catalog_metadata['product_count'], + "categories": catalog_metadata.get('categories', []), + "price_range": { + "min": catalog_metadata.get('price_range_min'), + "max": catalog_metadata.get('price_range_max') + }, + "upload_date": catalog_metadata['upload_date'], + "stats": product_stats + } + + +@app.get("/", summary="Root endpoint") +async def root(): + """ + Root endpoint with service information. + + Returns: + dict: Basic service info including version, status, and mode. + """ + return { + "service": "Document Ingestion Service", + "version": "1.0.0", + "status": "running", + "mode": system_mode, + "docs": "/docs", + "health": "/health" + } + + +# Application startup/shutdown +# Application startup +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting Ingestion Service on {settings.ingestion_host}:{settings.ingestion_port}") + logger.info(f"Deployment phase: {settings.deployment_phase}") + logger.info(f"Document storage: {settings.document_storage_path}") + logger.info(f"Index storage: {settings.index_storage_path}") + + uvicorn.run( + app, + host=settings.ingestion_host, # nosec B104 - Binding to all interfaces is intentional for Docker container + port=settings.ingestion_port, + log_level=settings.log_level.lower() + ) + diff --git a/sample_solutions/HybridSearch/api/ingestion/requirements.txt b/sample_solutions/HybridSearch/api/ingestion/requirements.txt new file mode 100644 index 00000000..fe2acf50 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/requirements.txt @@ -0,0 +1,50 @@ +# Document Ingestion Service Requirements +# API Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 + +# Document Processing +pypdf==6.6.0 +pdfplumber==0.10.3 +python-docx==1.1.0 +openpyxl==3.1.5 +python-pptx==1.0.2 + +# Advanced document processing (optional) +# unstructured==0.10.30 +# pdf2image==1.16.3 + +# Text processing +nltk>=3.9 +spacy==3.7.2 + +# Vector search and indexing +faiss-cpu==1.7.4 +rank-bm25==0.2.2 +numpy>=1.24.0 + +# Database +sqlalchemy==2.0.23 +aiosqlite==0.19.0 + +# HTTP Client (for calling embedding service) +httpx==0.25.1 + +# Logging +python-json-logger==2.0.7 + +# Environment +python-dotenv==1.0.0 + +# File handling +python-multipart>=0.0.22 +aiofiles==23.2.1 + +# Product catalog support +pandas>=2.0.0 + +# Utilities +python-magic==0.4.27 + diff --git a/sample_solutions/HybridSearch/api/ingestion/schemas/__init__.py b/sample_solutions/HybridSearch/api/ingestion/schemas/__init__.py new file mode 100644 index 00000000..0b5c5bfb --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/schemas/__init__.py @@ -0,0 +1,2 @@ +"""Product schemas module""" + diff --git a/sample_solutions/HybridSearch/api/ingestion/schemas/product_schemas.py b/sample_solutions/HybridSearch/api/ingestion/schemas/product_schemas.py new file mode 100644 index 00000000..5d4aff76 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/schemas/product_schemas.py @@ -0,0 +1,114 @@ +""" +Product Data Models +Pydantic models for product data structures +""" + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field +from datetime import datetime + + +class ProductBase(BaseModel): + """Base product model""" + name: str = Field(..., description="Product name/title") + description: Optional[str] = Field(None, description="Product description") + category: Optional[str] = Field(None, description="Product category") + price: Optional[float] = Field(None, description="Product price", ge=0) + rating: Optional[float] = Field(None, description="Product rating (0-5)", ge=0, le=5) + review_count: Optional[int] = Field(None, description="Number of reviews", ge=0) + image_url: Optional[str] = Field(None, description="Product image URL") + brand: Optional[str] = Field(None, description="Product brand") + + +class ProductCreate(ProductBase): + """Product creation model""" + id: Optional[str] = Field(None, description="Product ID (auto-generated if not provided)") + embedding_text: Optional[str] = Field(None, description="Text used for embedding") + + +class ProductInDB(ProductBase): + """Product model as stored in database""" + id: str + embedding_text: Optional[str] = None + created_at: datetime + updated_at: datetime + attributes: Dict[str, str] = Field(default_factory=dict, description="Product attributes") + + class Config: + from_attributes = True + + +class ProductAttribute(BaseModel): + """Product attribute model""" + attribute_name: str = Field(..., description="Attribute name (e.g., 'color', 'size')") + attribute_value: str = Field(..., description="Attribute value") + + +class CatalogMetadata(BaseModel): + """Catalog metadata model""" + catalog_name: str = Field(..., description="Catalog name") + upload_date: datetime + product_count: int = Field(..., description="Number of products", ge=0) + categories: List[str] = Field(default_factory=list, description="List of unique categories") + price_range_min: Optional[float] = Field(None, description="Minimum price in catalog") + price_range_max: Optional[float] = Field(None, description="Maximum price in catalog") + + +class ProductSearchResult(BaseModel): + """Product search result model""" + product_id: str + name: str + description: str = Field(..., description="Truncated description (~200 chars)") + category: Optional[str] = None + price: Optional[float] = None + rating: Optional[float] = None + review_count: Optional[int] = None + image_url: Optional[str] = None + relevance_score: float = Field(..., description="RRF relevance score", ge=0, le=1) + match_reasons: List[str] = Field(default_factory=list, description="Why this product matched") + attributes: Dict[str, str] = Field(default_factory=dict, description="Product attributes") + + +class SearchResponse(BaseModel): + """Search response model""" + query: str + interpreted_as: str = Field(..., description="Cleaned semantic query") + applied_filters: Dict[str, Any] = Field(default_factory=dict, description="Applied filters") + total_matches: int = Field(..., description="Total matches before limit", ge=0) + results: List[ProductSearchResult] + suggested_filters: List[str] = Field(default_factory=list, description="Suggested filter refinements") + + +class FieldMapping(BaseModel): + """Field mapping model for CSV/JSON parsing""" + name: Optional[str] = Field(None, description="Maps to product name") + description: Optional[str] = Field(None, description="Maps to product description") + category: Optional[str] = Field(None, description="Maps to product category") + price: Optional[str] = Field(None, description="Maps to product price") + rating: Optional[str] = Field(None, description="Maps to product rating") + review_count: Optional[str] = Field(None, description="Maps to review count") + image_url: Optional[str] = Field(None, description="Maps to image URL") + brand: Optional[str] = Field(None, description="Maps to product brand") + id: Optional[str] = Field(None, description="Maps to product ID") + + +class UploadResponse(BaseModel): + """Product upload response""" + job_id: str + status: str = Field(..., description="processing, pending_confirmation, or error") + detected_columns: List[str] = Field(default_factory=list, description="Detected CSV/JSON columns") + suggested_mapping: Optional[FieldMapping] = Field(None, description="Suggested field mapping") + requires_confirmation: bool = Field(False, description="Whether user confirmation is needed") + error_message: Optional[str] = Field(None, description="Error message if status is error") + + +class ProcessingStatus(BaseModel): + """Processing status response""" + job_id: str + status: str = Field(..., description="processing, complete, or error") + progress: float = Field(..., description="Progress percentage (0-1)", ge=0, le=1) + products_processed: int = Field(..., description="Number of products processed", ge=0) + products_total: int = Field(..., description="Total number of products", ge=0) + current_step: str = Field(..., description="Current processing step") + errors: List[str] = Field(default_factory=list, description="List of errors encountered") + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/__init__.py b/sample_solutions/HybridSearch/api/ingestion/services/__init__.py new file mode 100644 index 00000000..442810f7 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/__init__.py @@ -0,0 +1,4 @@ +""" +Ingestion Service Modules +""" + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/chunker.py b/sample_solutions/HybridSearch/api/ingestion/services/chunker.py new file mode 100644 index 00000000..80260afe --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/chunker.py @@ -0,0 +1,140 @@ +""" +Text Chunker +Splits text into chunks for embedding and retrieval +""" + +import logging +from typing import List, Dict +import re + +logger = logging.getLogger(__name__) + + +class TextChunker: + """Chunk text into smaller pieces for processing""" + + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + """ + Initialize chunker. + + Args: + chunk_size (int): Maximum number of tokens per chunk. + chunk_overlap (int): Number of tokens to overlap between chunks. + """ + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + logger.info(f"TextChunker initialized with chunk_size={chunk_size}, chunk_overlap={chunk_overlap}") + + @staticmethod + def simple_tokenize(text: str) -> List[str]: + """ + Simple tokenization by splitting on whitespace and punctuation marks. + + Args: + text (str): Input text to tokenize. + + Returns: + List[str]: List of token strings. + """ + # Split on whitespace and punctuation + tokens = re.findall(r'\w+|[^\w\s]', text) + return tokens + + def chunk_text(self, text: str, metadata: Dict = None) -> List[Dict]: + """ + Chunk text into overlapping segments based on token count. + + Args: + text (str): Input text to chunk. + metadata (Dict, optional): Optional metadata to attach to each chunk. + + Returns: + List[Dict]: List of chunk dictionaries with text, token counts, and metadata. + """ + if not text.strip(): + return [] + + metadata = metadata or {} + + # Tokenize + tokens = self.simple_tokenize(text) + logger.info(f"Tokenized text into {len(tokens)} tokens") + + if len(tokens) <= self.chunk_size: + # Text is smaller than chunk size, return as single chunk + return [{ + "text": text, + "token_count": len(tokens), + **metadata + }] + + chunks = [] + start = 0 + + while start < len(tokens): + # Get chunk + end = start + self.chunk_size + chunk_tokens = tokens[start:end] + + # Reconstruct text from tokens + chunk_text = " ".join(chunk_tokens) + + chunks.append({ + "text": chunk_text, + "token_count": len(chunk_tokens), + "start_token": start, + "end_token": end, + **metadata + }) + + # Move start position with overlap + start = end - self.chunk_overlap + + # Prevent infinite loop + if start >= len(tokens): + break + + logger.debug(f"Created {len(chunks)} chunks from {len(tokens)} tokens") + return chunks + + def chunk_document( + self, + pages_or_sections: Dict[int, str], + document_id: str + ) -> List[Dict]: + """ + Chunk an entire document preserving page/section context. + + Args: + pages_or_sections (Dict[int, str]): Dictionary mapping page/section numbers to text. + document_id (str): Unique document identifier. + + Returns: + List[Dict]: List of chunk dictionaries with added chunk IDs and metadata. + """ + all_chunks = [] + chunk_id_counter = 0 + + for page_num, text in pages_or_sections.items(): + # Chunk each page/section + page_chunks = self.chunk_text( + text, + metadata={ + "page_number": page_num, + "document_id": document_id + } + ) + + # Add chunk IDs + for chunk in page_chunks: + chunk_id_counter += 1 + chunk["chunk_id"] = f"{document_id}_chunk_{chunk_id_counter}" + all_chunks.append(chunk) + + logger.info( + f"Document {document_id}: Created {len(all_chunks)} chunks " + f"from {len(pages_or_sections)} pages/sections" + ) + + return all_chunks + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/document_parser.py b/sample_solutions/HybridSearch/api/ingestion/services/document_parser.py new file mode 100644 index 00000000..0df6e48b --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/document_parser.py @@ -0,0 +1,228 @@ +""" +Document Parser +Handles extraction of text from various document formats +""" + +import logging +from pathlib import Path +from typing import Dict, List, Optional +from pypdf import PdfReader # Modern replacement for PyPDF2 +from docx import Document +import openpyxl +from pptx import Presentation + +logger = logging.getLogger(__name__) + + +class DocumentParser: + """ + Parse documents and extract text content. + + Supports multiple file formats (PDF, DOCX, XLSX, PPTX, TXT) and + extracts text into page/section dictionaries. + """ + + @staticmethod + def parse_pdf(file_path: Path) -> Dict[int, str]: + """ + Parse PDF and extract text by page. + + Args: + file_path (Path): Path to the PDF file. + + Returns: + Dict[int, str]: Dictionary mapping page numbers (1-based) to text content. + + Raises: + ValueError: If parsing fails. + """ + try: + pages = {} + with open(file_path, 'rb') as file: + pdf_reader = PdfReader(file) + for page_num, page in enumerate(pdf_reader.pages, 1): + text = page.extract_text() + if text.strip(): + pages[page_num] = text + + logger.info(f"Extracted {len(pages)} pages from PDF: {file_path.name}") + return pages + + except Exception as e: + logger.error(f"Error parsing PDF {file_path}: {e}") + raise ValueError(f"Failed to parse PDF: {str(e)}") + + @staticmethod + def parse_docx(file_path: Path) -> Dict[int, str]: + """ + Parse DOCX and extract text by paragraph groupings (sections). + + Args: + file_path (Path): Path to the DOCX file. + + Returns: + Dict[int, str]: Dictionary mapping section numbers to text content. + + Raises: + ValueError: If parsing fails. + """ + try: + doc = Document(file_path) + sections = {} + section_num = 1 + current_text = [] + + for para in doc.paragraphs: + text = para.text.strip() + if text: + current_text.append(text) + # Group paragraphs into sections (every ~10 paragraphs) + if len(current_text) >= 10: + sections[section_num] = "\n".join(current_text) + current_text = [] + section_num += 1 + + # Add remaining text + if current_text: + sections[section_num] = "\n".join(current_text) + + logger.info(f"Extracted {len(sections)} sections from DOCX: {file_path.name}") + return sections + + except Exception as e: + logger.error(f"Error parsing DOCX {file_path}: {e}") + raise ValueError(f"Failed to parse DOCX: {str(e)}") + + @staticmethod + def parse_xlsx(file_path: Path) -> Dict[int, str]: + """ + Parse XLSX and extract text from sheets. + + Args: + file_path (Path): Path to the XLSX file. + + Returns: + Dict[int, str]: Dictionary mapping sheet numbers (1-based) to text content. + + Raises: + ValueError: If parsing fails. + """ + try: + workbook = openpyxl.load_workbook(file_path, data_only=True) + sheets = {} + + for sheet_num, sheet_name in enumerate(workbook.sheetnames, 1): + sheet = workbook[sheet_name] + rows = [] + + for row in sheet.iter_rows(values_only=True): + row_text = " | ".join(str(cell) if cell is not None else "" for cell in row) + if row_text.strip(): + rows.append(row_text) + + if rows: + sheets[sheet_num] = f"Sheet: {sheet_name}\n" + "\n".join(rows) + + logger.info(f"Extracted {len(sheets)} sheets from XLSX: {file_path.name}") + return sheets + + except Exception as e: + logger.error(f"Error parsing XLSX {file_path}: {e}") + raise ValueError(f"Failed to parse XLSX: {str(e)}") + + @staticmethod + def parse_pptx(file_path: Path) -> Dict[int, str]: + """ + Parse PPTX and extract text from slides. + + Args: + file_path (Path): Path to the PPTX file. + + Returns: + Dict[int, str]: Dictionary mapping slide numbers (1-based) to text content. + + Raises: + ValueError: If parsing fails. + """ + try: + prs = Presentation(file_path) + slides = {} + + for slide_num, slide in enumerate(prs.slides, 1): + text_parts = [] + + for shape in slide.shapes: + if hasattr(shape, "text"): + text = shape.text.strip() + if text: + text_parts.append(text) + + if text_parts: + slides[slide_num] = "\n".join(text_parts) + + logger.info(f"Extracted {len(slides)} slides from PPTX: {file_path.name}") + return slides + + except Exception as e: + logger.error(f"Error parsing PPTX {file_path}: {e}") + raise ValueError(f"Failed to parse PPTX: {str(e)}") + + @staticmethod + def parse_txt(file_path: Path) -> Dict[int, str]: + """ + Parse TXT file. + + Args: + file_path (Path): Path to the TXT file. + + Returns: + Dict[int, str]: Dictionary with a single entry {1: text_content}. + + Raises: + ValueError: If parsing fails. + """ + try: + with open(file_path, 'r', encoding='utf-8') as file: + text = file.read() + + if text.strip(): + return {1: text} + return {} + + except Exception as e: + logger.error(f"Error parsing TXT {file_path}: {e}") + raise ValueError(f"Failed to parse TXT: {str(e)}") + + def parse_document(self, file_path: Path, file_type: str) -> Dict[int, str]: + """ + Parse document based on file type. + + Dispatcher method that calls the appropriate parser. + + Args: + file_path (Path): Path to document file. + file_type (str): Type of file ('pdf', 'docx', 'xlsx', 'pptx', 'txt'). + + Returns: + Dict[int, str]: Dictionary mapping page/section numbers to text content. + + Raises: + ValueError: If file type is unsupported. + """ + file_type = file_type.lower().strip() + + parsers = { + 'pdf': self.parse_pdf, + 'docx': self.parse_docx, + 'xlsx': self.parse_xlsx, + 'pptx': self.parse_pptx, + 'ppt': self.parse_pptx, # Alias + 'txt': self.parse_txt + } + + parser = parsers.get(file_type) + if not parser: + raise ValueError(f"Unsupported file type: {file_type}") + + return parser(file_path) + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/index_manager.py b/sample_solutions/HybridSearch/api/ingestion/services/index_manager.py new file mode 100644 index 00000000..acb6daae --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/index_manager.py @@ -0,0 +1,481 @@ +""" +Index Manager +Manages FAISS vector index and BM25 sparse index +""" + +import logging +import pickle +from pathlib import Path +from typing import List, Dict, Optional +import numpy as np +import faiss +from rank_bm25 import BM25Okapi + +logger = logging.getLogger(__name__) + + +class IndexManager: + """ + Manage FAISS and BM25 indexes. + + Handles creation, loading, saving, and updating of vector (FAISS) + and sparse (BM25) indexes for both documents and products. + """ + + def __init__( + self, + index_storage_path: str, + embedding_dim: int = 768 + ): + """ + Initialize index manager. + + Args: + index_storage_path (str): Path to store index files. + embedding_dim (int): Dimension of embeddings (default: 768). + """ + self.index_storage_path = Path(index_storage_path) + self.index_storage_path.mkdir(parents=True, exist_ok=True) + + self.embedding_dim = embedding_dim + + # Document index paths + self.faiss_index_path = self.index_storage_path / "faiss_index.bin" + self.bm25_index_path = self.index_storage_path / "bm25_index.pkl" + self.metadata_path = self.index_storage_path / "metadata.pkl" + self.filters_cache_path = self.index_storage_path / "filters_cache.pkl" + + # Product index paths (separate from documents) + self.product_faiss_index_path = self.index_storage_path / "product_faiss_index.bin" + self.product_bm25_index_path = self.index_storage_path / "product_bm25_index.pkl" + self.product_metadata_path = self.index_storage_path / "product_metadata.pkl" + self.product_filters_cache_path = self.index_storage_path / "product_filters_cache.pkl" + + # Initialize document indexes + self.faiss_index = None + self.bm25_index = None + self.metadata = [] + + # Initialize product indexes (separate) + self.product_faiss_index = None + self.product_bm25_index = None + self.product_metadata = [] + + # Product-specific: filters cache for fast filtering + self.filters_cache = { + "prices": [], + "categories": [], + "ratings": [], + "product_ids": [] + } + + # Load existing indexes if available + self._load_indexes() + self._load_product_indexes() + + def _load_indexes(self): + """ + Load existing indexes from disk. + + Initializes FAISS and BM25 indexes and loads metadata. + Creates fresh indexes if files are missing or corrupt. + """ + try: + # Load FAISS index + if self.faiss_index_path.exists(): + self.faiss_index = faiss.read_index(str(self.faiss_index_path)) + logger.info(f"Loaded FAISS index with {self.faiss_index.ntotal} vectors (d={self.faiss_index.d})") + else: + # Create new FAISS index (IndexFlatIP for cosine similarity) + self.faiss_index = faiss.IndexFlatIP(self.embedding_dim) + logger.info(f"Created new FAISS index (d={self.embedding_dim})") + + # Load BM25 index + if self.bm25_index_path.exists(): + with open(self.bm25_index_path, 'rb') as f: + self.bm25_index = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info("Loaded BM25 index") + + # Load metadata + if self.metadata_path.exists(): + with open(self.metadata_path, 'rb') as f: + self.metadata = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info(f"Loaded {len(self.metadata)} metadata entries") + + # Load filters cache (for products) + if self.filters_cache_path.exists(): + with open(self.filters_cache_path, 'rb') as f: + self.filters_cache = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info(f"Loaded filters cache with {len(self.filters_cache.get('product_ids', []))} products") + + except Exception as e: + logger.error(f"Error loading indexes: {e}") + # Initialize new indexes + self.faiss_index = faiss.IndexFlatIP(self.embedding_dim) + self.bm25_index = None + self.metadata = [] + self.filters_cache = { + "prices": [], + "categories": [], + "ratings": [], + "product_ids": [] + } + + def _load_product_indexes(self): + """Load existing product indexes from disk""" + try: + if self.product_faiss_index_path.exists(): + self.product_faiss_index = faiss.read_index(str(self.product_faiss_index_path)) + logger.info(f"Loaded product FAISS index with {self.product_faiss_index.ntotal} vectors") + else: + self.product_faiss_index = faiss.IndexFlatIP(self.embedding_dim) + logger.info("Created new product FAISS index") + + if self.product_bm25_index_path.exists(): + with open(self.product_bm25_index_path, 'rb') as f: + self.product_bm25_index = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info("Loaded product BM25 index") + + if self.product_metadata_path.exists(): + with open(self.product_metadata_path, 'rb') as f: + self.product_metadata = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info(f"Loaded {len(self.product_metadata)} product metadata entries") + + if self.product_filters_cache_path.exists(): + with open(self.product_filters_cache_path, 'rb') as f: + product_filters = pickle.load(f) # nosec B301 - indexes are written by this application + self.filters_cache.update(product_filters) + logger.debug("Loaded product filters cache") + + except Exception as e: + logger.error(f"Error loading product indexes: {e}") + self.product_faiss_index = faiss.IndexFlatIP(self.embedding_dim) + self.product_bm25_index = None + self.product_metadata = [] + + def _save_indexes(self): + """ + Save indexes to disk. + + Persists FAISS index, BM25 index, metadata, and cache to configured storage paths. + + Raises: + IOError: If saving fails. + """ + try: + # Save FAISS index + faiss.write_index(self.faiss_index, str(self.faiss_index_path)) + logger.info(f"Saved FAISS index ({self.faiss_index.ntotal} vectors)") + + # Save BM25 index + if self.bm25_index: + with open(self.bm25_index_path, 'wb') as f: + pickle.dump(self.bm25_index, f) + logger.info("Saved BM25 index") + + # Save metadata + with open(self.metadata_path, 'wb') as f: + pickle.dump(self.metadata, f) + logger.info(f"Saved {len(self.metadata)} metadata entries") + + # Save filters cache (for products) + with open(self.filters_cache_path, 'wb') as f: + pickle.dump(self.filters_cache, f) + logger.debug("Saved filters cache") + + except Exception as e: + logger.error(f"Error saving indexes: {e}", exc_info=True) + raise + + def add_chunks( + self, + chunks: List[Dict], + embeddings: List[List[float]], + content_type: str = "document" + ): + """ + Add chunks to indexes. + + Updates both FAISS (dense) and BM25 (sparse) indexes with new chunks. + + Args: + chunks (List[Dict]): List of chunk dictionaries with text and metadata. + embeddings (List[List[float]]): List of embedding vectors corresponding to chunks. + content_type (str): Type of content ("document" or "product"). + + Raises: + ValueError: If chunks and embeddings counts mismatch. + """ + if len(chunks) != len(embeddings): + raise ValueError("Number of chunks must match number of embeddings") + + if not chunks: + logger.warning("No chunks to add") + return + + # Add to FAISS index + embeddings_array = np.array(embeddings, dtype=np.float32) + + # Normalize vectors for cosine similarity + faiss.normalize_L2(embeddings_array) + + # Add debug logging for dimensions + logger.info(f"Adding embeddings to FAISS index: embedding_dim={embeddings_array.shape[1]}, index_dim={self.faiss_index.d}") + + # Add to index + self.faiss_index.add(embeddings_array) + + # Add metadata with content_type + for chunk in chunks: + chunk['content_type'] = content_type + self.metadata.append(chunk) + + # Rebuild BM25 index with all texts + all_texts = [chunk["text"] for chunk in self.metadata] + tokenized_corpus = [text.lower().split() for text in all_texts] + self.bm25_index = BM25Okapi(tokenized_corpus) + + logger.info(f"Added {len(chunks)} chunks to indexes (content_type={content_type})") + + # Save indexes + self._save_indexes() + + def get_stats(self) -> Dict: + """ + Get index statistics. + + Returns: + Dict: Dictionary containing detailed index stats (counts, dims, etc.). + """ + return { + "total_chunks": len(self.metadata), + "faiss_vectors": self.faiss_index.ntotal if self.faiss_index else 0, + "bm25_enabled": self.bm25_index is not None, + "embedding_dim": self.embedding_dim + } + + def delete_document(self, document_id: str): + """ + Delete all chunks for a document. + + Removed chunks from metadata and rebuilds indexes (expensive operation). + + Args: + document_id (str): Document ID to delete. + """ + # Find indices to remove + indices_to_remove = [] + new_metadata = [] + + for idx, chunk in enumerate(self.metadata): + if chunk.get("document_id") == document_id: + indices_to_remove.append(idx) + else: + new_metadata.append(chunk) + + if not indices_to_remove: + logger.warning(f"No chunks found for document {document_id}") + return + + # For simplicity, rebuild indexes without deleted chunks + # In production, consider more efficient index update strategies + logger.info(f"Rebuilding indexes after deleting {len(indices_to_remove)} chunks") + + # Rebuild FAISS index + self.faiss_index = faiss.IndexFlatIP(self.embedding_dim) + + # Add back remaining chunks (this requires re-embedding, which we skip for now) + # In production, store embeddings with metadata and rebuild from those + self.metadata = new_metadata + + # Rebuild BM25 + if new_metadata: + all_texts = [chunk["text"] for chunk in new_metadata] + tokenized_corpus = [text.lower().split() for text in all_texts] + self.bm25_index = BM25Okapi(tokenized_corpus) + else: + self.bm25_index = None + + # Save indexes + self._save_indexes() + + logger.info(f"Deleted document {document_id}") + + def clear_all(self): + """Clear all indexes""" + self.faiss_index = faiss.IndexFlatIP(self.embedding_dim) + self.bm25_index = None + self.metadata = [] + + # Clear product indexes + self.product_faiss_index = faiss.IndexFlatIP(self.embedding_dim) + self.product_bm25_index = None + self.product_metadata = [] + + self.filters_cache = { + "prices": [], + "categories": [], + "ratings": [], + "product_ids": [] + } + + self._save_indexes() + self._save_product_indexes() + logger.info("Cleared all indexes (documents and products)") + + def clear_products_only(self): + """ + Clear only products from the unified index, keeping documents intact. + + Filters metadata to remove product entries and rebuilds BM25. + Note: FAISS vectors remain but metadata filtering ensures only current content is returned. + """ + # Count existing products before clearing + product_count = sum(1 for chunk in self.metadata if chunk.get('content_type') == 'product') + + # Filter out products, keep only documents + self.metadata = [chunk for chunk in self.metadata if chunk.get('content_type') != 'product'] + + # Rebuild BM25 index with remaining content (documents only) + if self.metadata: + all_texts = [chunk["text"] for chunk in self.metadata] + tokenized_corpus = [text.lower().split() for text in all_texts] + self.bm25_index = BM25Okapi(tokenized_corpus) + else: + self.bm25_index = None + + # Clear filters cache (product-specific) + self.filters_cache = { + "prices": [], + "categories": [], + "ratings": [], + "product_ids": [] + } + + # Save the updated metadata and BM25 index + # Note: FAISS index keeps old vectors but metadata filtering ensures correctness + self._save_indexes() + + logger.info(f"Cleared {product_count} products from unified index. {len(self.metadata)} documents remain") + + def add_products( + self, + products: List[Dict], + embeddings: List[List[float]] + ): + """ + Add products to unified indexes (uses main index with content_type="product"). + + Args: + products (List[Dict]): List of product dictionaries with metadata. + embeddings (List[List[float]]): List of embedding vectors (one per product). + + Raises: + ValueError: If products and embeddings counts mismatch. + """ + if len(products) != len(embeddings): + raise ValueError("Number of products must match number of embeddings") + + if not products: + logger.warning("No products to add") + return + + # Create chunks with product metadata and content_type + chunks = [] + for product in products: + chunk = { + 'chunk_id': f"{product.get('product_id', product.get('id'))}_chunk", + 'document_id': product.get('product_id', product.get('id')), + 'text': product.get('embedding_text', ''), + 'content_type': 'product', # Mark as product + 'metadata': { + 'product_id': product.get('product_id', product.get('id')), + 'name': product.get('name'), + 'category': product.get('category'), + 'price': product.get('price'), + 'rating': product.get('rating'), + 'brand': product.get('brand'), + 'description': product.get('description'), + 'review_count': product.get('review_count'), + 'image_url': product.get('image_url') + } + } + chunks.append(chunk) + + # Update filters cache for fast filtering + product_id = product.get('product_id', product.get('id')) + self.filters_cache['product_ids'].append(product_id) + self.filters_cache['prices'].append(product.get('price')) + self.filters_cache['categories'].append(product.get('category')) + self.filters_cache['ratings'].append(product.get('rating')) + + # Add to UNIFIED FAISS index (same as documents) + embeddings_array = np.array(embeddings, dtype=np.float32) + faiss.normalize_L2(embeddings_array) + self.faiss_index.add(embeddings_array) + + # Add metadata to unified metadata list + for chunk in chunks: + self.metadata.append(chunk) + + # Rebuild unified BM25 index with all content (documents + products) + all_texts = [chunk["text"] for chunk in self.metadata] + tokenized_corpus = [text.lower().split() for text in all_texts] + self.bm25_index = BM25Okapi(tokenized_corpus) + + logger.info(f"Added {len(products)} products to unified indexes (content_type=product)") + + # Save unified indexes + self._save_indexes() + + def _save_product_indexes(self): + """Save product indexes to disk""" + try: + # Save product FAISS index + faiss.write_index(self.product_faiss_index, str(self.product_faiss_index_path)) + logger.info(f"Saved product FAISS index ({self.product_faiss_index.ntotal} vectors)") + + # Save product BM25 index + if self.product_bm25_index: + with open(self.product_bm25_index_path, 'wb') as f: + pickle.dump(self.product_bm25_index, f) + logger.info("Saved product BM25 index") + + # Save product metadata + with open(self.product_metadata_path, 'wb') as f: + pickle.dump(self.product_metadata, f) + logger.info(f"Saved {len(self.product_metadata)} product metadata entries") + + # Save product filters cache + with open(self.product_filters_cache_path, 'wb') as f: + pickle.dump(self.filters_cache, f) + logger.debug("Saved product filters cache") + + except Exception as e: + logger.error(f"Error saving product indexes: {e}") + raise + + def get_filters_cache(self) -> Dict: + """ + Get filters cache for fast filtering + + Returns: + Dictionary with filterable fields aligned with vector indices + """ + return self.filters_cache.copy() + + def get_product_metadata(self, product_id: str) -> Optional[Dict]: + """ + Get product metadata by product_id. + + Args: + product_id (str): Product identifier. + + Returns: + Optional[Dict]: Product metadata dictionary or None if not found. + """ + for chunk in self.metadata: + if chunk.get('metadata', {}).get('product_id') == product_id: + return chunk.get('metadata') + return None + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/metadata_store.py b/sample_solutions/HybridSearch/api/ingestion/services/metadata_store.py new file mode 100644 index 00000000..adc62ce8 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/metadata_store.py @@ -0,0 +1,651 @@ +""" +Metadata Store +SQLite database for document metadata and processing status +""" + +import logging +import sqlite3 +import os +from pathlib import Path +from typing import Dict, List, Optional +from datetime import datetime +import json + +logger = logging.getLogger(__name__) + + +class MetadataStore: + """ + SQLite-based metadata storage. + + Manages persistence for document and product metadata, processing status, + and catalog statistics using a local SQLite database. + """ + + def __init__(self, db_path: str): + """ + Initialize metadata store. + + Args: + db_path (str): Path to SQLite database file. + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + self.conn = None + self._connect() + self._initialize_schema() + + def _connect(self): + """ + Connect to database. + + Establishes connection, enables WAL mode for concurrency, and configures + synchronous mode for performance. + + Raises: + PermissionError: If database file is read-only. + sqlite3.Error: For other database connection issues. + """ + try: + # Ensure parent directory exists and is writable + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + # Check if database file exists and is writable + if self.db_path.exists(): + if not os.access(self.db_path, os.W_OK): + logger.warning(f"Database file {self.db_path} is not writable, attempting to fix permissions") + import stat + self.db_path.chmod(stat.S_IWRITE | stat.S_IREAD) + + # Connect with timeout to handle locked database + # Use default isolation level (not autocommit) for transaction safety + self.conn = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + timeout=10.0 # 10 second timeout for locked database + ) + self.conn.row_factory = sqlite3.Row # Return rows as dictionaries + + # Enable WAL mode for better concurrency (allows multiple readers + one writer) + cursor = self.conn.cursor() + try: + cursor.execute("PRAGMA journal_mode=WAL") + result = cursor.fetchone() + logger.info(f"Database journal mode: {result[0] if result else 'unknown'}") + except Exception as e: + logger.warning(f"Could not enable WAL mode: {e}") + + try: + cursor.execute("PRAGMA synchronous=NORMAL") # Faster writes, still safe + except Exception as e: + logger.warning(f"Could not set synchronous mode: {e}") + + self.conn.commit() + + logger.info(f"Connected to database: {self.db_path}") + except sqlite3.OperationalError as e: + if "readonly" in str(e).lower() or "read-only" in str(e).lower(): + logger.error(f"Database is read-only: {self.db_path}. Check file permissions.") + raise PermissionError(f"Database file is read-only: {self.db_path}") + raise + except Exception as e: + logger.error(f"Failed to connect to database {self.db_path}: {e}") + raise + + def _initialize_schema(self): + """ + Initialize database schema. + + Creates tables for documents, products, product attributes, and catalog metadata + if they do not already exist. Also sets up performance indexes. + """ + cursor = self.conn.cursor() + + # Documents table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS documents ( + document_id TEXT PRIMARY KEY, + filename TEXT NOT NULL, + file_type TEXT NOT NULL, + file_size INTEGER, + upload_timestamp TEXT NOT NULL, + processing_status TEXT NOT NULL, + chunk_count INTEGER DEFAULT 0, + error_message TEXT, + metadata TEXT + ) + """) + + # Products table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS products ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + category TEXT, + price REAL, + rating REAL, + review_count INTEGER, + image_url TEXT, + brand TEXT, + embedding_text TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Product attributes table (flexible key-value pairs) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS product_attributes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + product_id TEXT NOT NULL, + attribute_name TEXT NOT NULL, + attribute_value TEXT, + FOREIGN KEY (product_id) REFERENCES products(id) ON DELETE CASCADE, + UNIQUE(product_id, attribute_name) + ) + """) + + # Catalog metadata table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS catalog_metadata ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + catalog_name TEXT NOT NULL, + upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + product_count INTEGER DEFAULT 0, + categories TEXT, + price_range_min REAL, + price_range_max REAL + ) + """) + + # Create indexes for better query performance + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_products_category ON products(category) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_products_price ON products(price) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_products_rating ON products(rating) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_product_attributes_product_id + ON product_attributes(product_id) + """) + + # Commit explicitly (even though we're in autocommit mode, this ensures it's written) + self.conn.commit() + logger.info("Database schema initialized") + + def add_document( + self, + document_id: str, + filename: str, + file_type: str, + file_size: int, + metadata: Dict = None + ): + """ + Add new document record. + + Args: + document_id (str): Unique document identifier. + filename (str): Original filename. + file_type (str): File extension/type. + file_size (int): File size in bytes. + metadata (Dict, optional): Additional metadata dictionary. + """ + cursor = self.conn.cursor() + + cursor.execute(""" + INSERT INTO documents + (document_id, filename, file_type, file_size, upload_timestamp, + processing_status, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, ( + document_id, + filename, + file_type, + file_size, + datetime.utcnow().isoformat(), + "pending", + json.dumps(metadata or {}) + )) + + self.conn.commit() + logger.info(f"Added document: {document_id}") + + def update_status( + self, + document_id: str, + status: str, + chunk_count: int = None, + error_message: str = None + ): + """ + Update document processing status. + + Args: + document_id (str): Document identifier. + status (str): New status ('pending', 'processing', 'completed', 'failed'). + chunk_count (int, optional): Number of chunks created. + error_message (str, optional): Error message if failed. + """ + cursor = self.conn.cursor() + + query = "UPDATE documents SET processing_status = ?" + params = [status] + + if chunk_count is not None: + query += ", chunk_count = ?" + params.append(chunk_count) + + if error_message is not None: + query += ", error_message = ?" + params.append(error_message) + + query += " WHERE document_id = ?" + params.append(document_id) + + cursor.execute(query, params) + self.conn.commit() + + logger.info(f"Updated document {document_id} status to: {status}") + + def get_document(self, document_id: str) -> Optional[Dict]: + """ + Get document by ID + + Args: + document_id: Document identifier + + Returns: + Document dictionary or None + """ + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM documents WHERE document_id = ?", + (document_id,) + ) + row = cursor.fetchone() + + if row: + doc = dict(row) + doc['metadata'] = json.loads(doc['metadata']) + return doc + return None + + def list_documents( + self, + status: str = None, + limit: int = 100 + ) -> List[Dict]: + """ + List documents. + + Args: + status (str, optional): Filter by processing status. + limit (int): Maximum number of documents to return. + + Returns: + List[Dict]: List of document dictionaries. + """ + cursor = self.conn.cursor() + + if status: + cursor.execute( + "SELECT * FROM documents WHERE processing_status = ? " + "ORDER BY upload_timestamp DESC LIMIT ?", + (status, limit) + ) + else: + cursor.execute( + "SELECT * FROM documents ORDER BY upload_timestamp DESC LIMIT ?", + (limit,) + ) + + rows = cursor.fetchall() + documents = [] + + for row in rows: + doc = dict(row) + doc['metadata'] = json.loads(doc['metadata']) + documents.append(doc) + + return documents + + def delete_document(self, document_id: str): + """ + Delete document record + + Args: + document_id: Document identifier + """ + cursor = self.conn.cursor() + cursor.execute( + "DELETE FROM documents WHERE document_id = ?", + (document_id,) + ) + self.conn.commit() + logger.info(f"Deleted document: {document_id}") + + def get_stats(self) -> Dict: + """ + Get database statistics. + + Returns: + Dict: Dictionary with total documents, status breakdowns, and chunk counts. + """ + cursor = self.conn.cursor() + + cursor.execute("SELECT COUNT(*) as total FROM documents") + total = cursor.fetchone()['total'] + + cursor.execute( + "SELECT processing_status, COUNT(*) as count " + "FROM documents GROUP BY processing_status" + ) + status_counts = {row['processing_status']: row['count'] for row in cursor.fetchall()} + + cursor.execute("SELECT SUM(chunk_count) as total_chunks FROM documents") + total_chunks = cursor.fetchone()['total_chunks'] or 0 + + return { + "total_documents": total, + "status_counts": status_counts, + "total_chunks": total_chunks + } + + def clear_all(self): + """ + Clear all documents from the database + + WARNING: This operation cannot be undone! + """ + cursor = self.conn.cursor() + cursor.execute("DELETE FROM documents") + self.conn.commit() + logger.warning("All documents cleared from database") + + def close(self): + """Close database connection""" + if self.conn: + self.conn.close() + logger.info("Database connection closed") + + # Product-related methods + def add_product( + self, + product_id: str, + name: str, + description: str = None, + category: str = None, + price: float = None, + rating: float = None, + review_count: int = None, + image_url: str = None, + brand: str = None, + embedding_text: str = None + ): + """ + Add new product record + + Args: + product_id: Unique product identifier + name: Product name/title + description: Product description + category: Product category + price: Product price + rating: Product rating (0-5) + review_count: Number of reviews + image_url: Product image URL + brand: Product brand + embedding_text: Text used for embedding + """ + cursor = self.conn.cursor() + + cursor.execute(""" + INSERT OR REPLACE INTO products + (id, name, description, category, price, rating, review_count, + image_url, brand, embedding_text, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + """, ( + product_id, + name, + description, + category, + price, + rating, + review_count, + image_url, + brand, + embedding_text + )) + + self.conn.commit() + logger.debug(f"Added product: {product_id}") + + def add_product_attribute( + self, + product_id: str, + attribute_name: str, + attribute_value: str + ): + """ + Add product attribute + + Args: + product_id: Product identifier + attribute_name: Attribute name (e.g., "color", "size") + attribute_value: Attribute value + """ + cursor = self.conn.cursor() + + cursor.execute(""" + INSERT OR REPLACE INTO product_attributes + (product_id, attribute_name, attribute_value) + VALUES (?, ?, ?) + """, (product_id, attribute_name, attribute_value)) + + self.conn.commit() + + def get_product(self, product_id: str) -> Optional[Dict]: + """ + Get product by ID + + Args: + product_id: Product identifier + + Returns: + Product dictionary or None + """ + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM products WHERE id = ?", + (product_id,) + ) + row = cursor.fetchone() + + if row: + product = dict(row) + # Get attributes + cursor.execute( + "SELECT attribute_name, attribute_value FROM product_attributes WHERE product_id = ?", + (product_id,) + ) + attributes = {row['attribute_name']: row['attribute_value'] for row in cursor.fetchall()} + product['attributes'] = attributes + return product + return None + + def list_products( + self, + category: str = None, + price_min: float = None, + price_max: float = None, + rating_min: float = None, + limit: int = 100, + offset: int = 0 + ) -> List[Dict]: + """ + List products with optional filters + + Args: + category: Filter by category + price_min: Minimum price + price_max: Maximum price + rating_min: Minimum rating + limit: Maximum number of products to return + offset: Offset for pagination + + Returns: + List of product dictionaries + """ + cursor = self.conn.cursor() + + query = "SELECT * FROM products WHERE 1=1" + params = [] + + if category: + query += " AND category = ?" + params.append(category) + + if price_min is not None: + query += " AND price >= ?" + params.append(price_min) + + if price_max is not None: + query += " AND price <= ?" + params.append(price_max) + + if rating_min is not None: + query += " AND rating >= ?" + params.append(rating_min) + + query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + cursor.execute(query, params) + rows = cursor.fetchall() + + products = [] + for row in rows: + product = dict(row) + # Get attributes for each product + cursor.execute( + "SELECT attribute_name, attribute_value FROM product_attributes WHERE product_id = ?", + (product['id'],) + ) + attributes = {row['attribute_name']: row['attribute_value'] for row in cursor.fetchall()} + product['attributes'] = attributes + products.append(product) + + return products + + def get_product_stats(self) -> Dict: + """ + Get product statistics + + Returns: + Dictionary with product statistics + """ + cursor = self.conn.cursor() + + cursor.execute("SELECT COUNT(*) as total FROM products") + total = cursor.fetchone()['total'] + + cursor.execute("SELECT COUNT(DISTINCT category) as categories FROM products WHERE category IS NOT NULL") + categories = cursor.fetchone()['categories'] + + cursor.execute("SELECT MIN(price) as min_price, MAX(price) as max_price, AVG(price) as avg_price FROM products WHERE price IS NOT NULL") + price_stats = cursor.fetchone() + + cursor.execute("SELECT MIN(rating) as min_rating, MAX(rating) as max_rating, AVG(rating) as avg_rating FROM products WHERE rating IS NOT NULL") + rating_stats = cursor.fetchone() + + return { + "total_products": total, + "total_categories": categories, + "price_range": { + "min": price_stats['min_price'], + "max": price_stats['max_price'], + "avg": price_stats['avg_price'] + } if price_stats['min_price'] else None, + "rating_range": { + "min": rating_stats['min_rating'], + "max": rating_stats['max_rating'], + "avg": rating_stats['avg_rating'] + } if rating_stats['min_rating'] else None + } + + def update_catalog_metadata( + self, + catalog_name: str, + product_count: int, + categories: List[str] = None, + price_range_min: float = None, + price_range_max: float = None + ): + """ + Update catalog metadata + + Args: + catalog_name: Name of the catalog + product_count: Number of products + categories: List of unique categories + price_range_min: Minimum price in catalog + price_range_max: Maximum price in catalog + """ + cursor = self.conn.cursor() + + # Clear existing catalog metadata (single catalog mode) + cursor.execute("DELETE FROM catalog_metadata") + + cursor.execute(""" + INSERT INTO catalog_metadata + (catalog_name, product_count, categories, price_range_min, price_range_max) + VALUES (?, ?, ?, ?, ?) + """, ( + catalog_name, + product_count, + json.dumps(categories) if categories else None, + price_range_min, + price_range_max + )) + + self.conn.commit() + logger.info(f"Updated catalog metadata: {catalog_name} ({product_count} products)") + + def get_catalog_metadata(self) -> Optional[Dict]: + """ + Get current catalog metadata + + Returns: + Catalog metadata dictionary or None + """ + cursor = self.conn.cursor() + cursor.execute("SELECT * FROM catalog_metadata ORDER BY upload_date DESC LIMIT 1") + row = cursor.fetchone() + + if row: + metadata = dict(row) + if metadata.get('categories'): + metadata['categories'] = json.loads(metadata['categories']) + return metadata + return None + + def clear_all_products(self): + """ + Clear all products from the database + + WARNING: This operation cannot be undone! + """ + cursor = self.conn.cursor() + cursor.execute("DELETE FROM products") + cursor.execute("DELETE FROM product_attributes") + cursor.execute("DELETE FROM catalog_metadata") + self.conn.commit() + logger.warning("All products cleared from database") + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/product_parser.py b/sample_solutions/HybridSearch/api/ingestion/services/product_parser.py new file mode 100644 index 00000000..d81e8bd3 --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/product_parser.py @@ -0,0 +1,294 @@ +""" +Product Parser +Handles CSV/JSON parsing with field detection and mapping +""" + +import logging +import csv +import json +import io +from pathlib import Path +from typing import Dict, List, Optional, Any +import pandas as pd +from schemas.product_schemas import FieldMapping + +logger = logging.getLogger(__name__) + + +class ProductParser: + """ + Parse product data from CSV/JSON/XLSX files. + + Handlers format detection, parsing, and column mapping normalization + to standardized product fields. + """ + + # Common field name variations + FIELD_MAPPINGS = { + 'name': ['name', 'title', 'product_name', 'product_title', 'item_name', 'item'], + 'description': ['description', 'desc', 'details', 'product_description', 'summary'], + 'category': ['category', 'categories', 'cat', 'product_category', 'type'], + 'price': ['price', 'cost', 'amount', 'product_price', 'cost_price', 'list_price'], + 'rating': ['rating', 'stars', 'star_rating', 'avg_rating', 'average_rating', 'score'], + 'review_count': ['review_count', 'reviews', 'num_reviews', 'review_number', 'total_reviews'], + 'image_url': ['image_url', 'image', 'img', 'image_link', 'picture', 'photo'], + 'brand': ['brand', 'manufacturer', 'maker', 'company', 'vendor'], + 'id': ['id', 'product_id', 'item_id', 'sku', 'asin', 'identifier'] + } + + def __init__(self): + """Initialize product parser""" + pass + + def detect_file_type(self, filename: str) -> str: + """ + Detect file type from filename extension. + + Args: + filename (str): Name of the file. + + Returns: + str: Detected type ('csv', 'json', 'xlsx', or 'unknown'). + """ + ext = Path(filename).suffix.lower() + if ext == '.csv': + return 'csv' + elif ext == '.json': + return 'json' + elif ext in ['.xlsx', '.xls']: + return 'xlsx' + else: + return 'unknown' + + def parse_csv(self, content: bytes, filename: str) -> tuple[List[Dict], List[str]]: + """ + Parse CSV file content. + + Args: + content (bytes): File content bytes. + filename (str): Original filename (used for logging). + + Returns: + tuple[List[Dict], List[str]]: Tuple containing: + - List of product dictionaries (rows) + - List of column names found + """ + try: + # Try to detect encoding + text = content.decode('utf-8') + except UnicodeDecodeError: + try: + text = content.decode('latin-1') + except UnicodeDecodeError: + text = content.decode('utf-8', errors='ignore') + + # Parse CSV + csv_reader = csv.DictReader(io.StringIO(text)) + columns = csv_reader.fieldnames or [] + + # Ensure columns is a list of strings + if columns: + columns = [str(col) if not isinstance(col, str) else col for col in columns] + else: + columns = [] + + products = [] + for row in csv_reader: + # Convert empty strings to None + clean_row = {k: (v if v and v.strip() else None) for k, v in row.items()} + products.append(clean_row) + + logger.info(f"Parsed CSV: {len(products)} products, {len(columns)} columns") + return products, columns + + def parse_json(self, content: bytes) -> tuple[List[Dict], List[str]]: + """ + Parse JSON file content. + + Handles simple lists, or wrapped objects (e.g. {'products': [...]}). + + Args: + content (bytes): File content bytes. + + Returns: + tuple[List[Dict], List[str]]: Tuple containing: + - List of product dictionaries + - List of unique field names found across all products + + Raises: + ValueError: If JSON structure is invalid/unsupported. + """ + try: + text = content.decode('utf-8') + data = json.loads(text) + except json.JSONDecodeError as e: + logger.error(f"JSON parse error: {e}") + raise ValueError(f"Invalid JSON format: {e}") + + # Handle different JSON structures + if isinstance(data, list): + products = data + elif isinstance(data, dict): + # Check if it's a wrapper object + if 'products' in data: + products = data['products'] + elif 'items' in data: + products = data['items'] + else: + # Single product + products = [data] + else: + raise ValueError("JSON must be an object or array") + + # Get all unique field names + all_fields = set() + for product in products: + if isinstance(product, dict): + all_fields.update(product.keys()) + + logger.info(f"Parsed JSON: {len(products)} products, {len(all_fields)} fields") + return products, list(all_fields) + + def parse_xlsx(self, content: bytes) -> tuple[List[Dict], List[str]]: + """ + Parse Excel (XLSX) file content. + + Args: + content (bytes): File content bytes. + + Returns: + tuple[List[Dict], List[str]]: Tuple containing: + - List of product dictionaries (rows) + - List of column names + + Raises: + ValueError: If parsing fails. + """ + try: + # Read Excel file + df = pd.read_excel(io.BytesIO(content)) + columns = df.columns.tolist() + + # Convert to list of dictionaries + products = df.replace({pd.NA: None, '': None}).to_dict('records') + + logger.info(f"Parsed XLSX: {len(products)} products, {len(columns)} columns") + return products, columns + except Exception as e: + logger.error(f"XLSX parse error: {e}") + raise ValueError(f"Failed to parse XLSX file: {e}") + + def detect_field_mapping(self, columns: List[str]) -> FieldMapping: + """ + Auto-detect field mapping from column names. + + Uses common variations (e.g., 'cost' -> 'price') to suggest mappings. + + Args: + columns (List[str]): List of column names from the file. + + Returns: + FieldMapping: Object containing suggested field-to-column mappings. + """ + mapping = FieldMapping() + + # Normalize column names for matching (handle tuples/strings) + normalized_columns = {} + for col in columns: + # Convert to string if it's a tuple or other type + col_str = str(col) if not isinstance(col, str) else col + normalized_key = col_str.lower().strip().replace(' ', '_').replace('-', '_') + normalized_columns[normalized_key] = col_str + + # Try to match each field + for field_name, variations in self.FIELD_MAPPINGS.items(): + for variation in variations: + normalized_var = variation.lower().strip().replace(' ', '_').replace('-', '_') + if normalized_var in normalized_columns: + # Found a match + setattr(mapping, field_name, normalized_columns[normalized_var]) + break + + logger.info(f"Detected field mapping: {mapping.model_dump(exclude_none=True)}") + return mapping + + def parse_file( + self, + content: bytes, + filename: str, + field_mapping: Optional[FieldMapping] = None + ) -> tuple[List[Dict], List[str], FieldMapping]: + """ + Parse product file and return products with field mapping. + + Main entry point that delegates to specific parsers based on file type. + + Args: + content (bytes): File content. + filename (str): Name of file. + field_mapping (Optional[FieldMapping]): Existing mapping to use, or None to auto-detect. + + Returns: + tuple: (products, columns, field_mapping) + + Raises: + ValueError: If file type unsupported or required fields missing. + """ + file_type = self.detect_file_type(filename) + + if file_type == 'csv': + products, columns = self.parse_csv(content, filename) + elif file_type == 'json': + products, columns = self.parse_json(content) + elif file_type == 'xlsx': + products, columns = self.parse_xlsx(content) + else: + raise ValueError(f"Unsupported file type: {file_type}") + + # Ensure columns are strings + columns = [str(col) if not isinstance(col, str) else col for col in columns] + + # Auto-detect mapping if not provided + if field_mapping is None: + field_mapping = self.detect_field_mapping(columns) + + # Validate required fields + if not field_mapping.name: + raise ValueError("Required field 'name' not found. Please provide field mapping.") + + logger.info(f"Parsed {len(products)} products from {filename}") + return products, columns, field_mapping + + def apply_field_mapping( + self, + products: List[Dict], + field_mapping: FieldMapping + ) -> List[Dict]: + """ + Apply field mapping to normalize product dictionaries. + + Renames keys in product dictionaries according to the mapping. + + Args: + products (List[Dict]): List of raw product dictionaries. + field_mapping (FieldMapping): Mapping configuration. + + Returns: + List[Dict]: List of normalized product dictionaries. + """ + normalized_products = [] + + for product in products: + normalized = {} + mapping_dict = field_mapping.model_dump(exclude_none=True) + + for target_field, source_field in mapping_dict.items(): + if source_field in product: + normalized[target_field] = product[source_field] + else: + normalized[target_field] = None + + normalized_products.append(normalized) + + return normalized_products + diff --git a/sample_solutions/HybridSearch/api/ingestion/services/product_processor.py b/sample_solutions/HybridSearch/api/ingestion/services/product_processor.py new file mode 100644 index 00000000..3c6b513f --- /dev/null +++ b/sample_solutions/HybridSearch/api/ingestion/services/product_processor.py @@ -0,0 +1,400 @@ +""" +Product Processor +Handles data validation, text preparation, and batch processing +""" + +import logging +import re +import uuid +from typing import Dict, List, Optional, Tuple +from datetime import datetime +from schemas.product_schemas import ProductCreate + +logger = logging.getLogger(__name__) + + +class ProductProcessor: + """ + Process and validate product data. + + Handles data cleaning, normalization (price, rating), validation, and + generation of embedding text fields. + """ + + def __init__(self, embedding_field_template: str = None): + """ + Initialize product processor. + + Args: + embedding_field_template (str, optional): Template for creating embedding text. + Default: "{name}. {description}. Category: {category}. Brand: {brand}" + """ + self.embedding_field_template = embedding_field_template or \ + "{name}. {description}. Category: {category}. Brand: {brand}" + + def normalize_price(self, price: any) -> Optional[float]: + """ + Normalize price from various formats. + + Handles strings with currency symbols ($, €, etc.), commas, and ISO codes. + + Args: + price (any): Price value (string, float, int, or None). + + Returns: + Optional[float]: Normalized price as float, or None if invalid. + """ + if price is None: + return None + + if isinstance(price, (int, float)): + return float(price) if price >= 0 else None + + if isinstance(price, str): + # Remove currency symbols and whitespace + price_str = price.strip() + if not price_str: + return None + + # Remove common currency symbols + price_str = re.sub(r'[$€£¥₹]', '', price_str) + + # Remove "USD", "EUR", etc. + price_str = re.sub(r'\b(USD|EUR|GBP|JPY|INR)\b', '', price_str, flags=re.IGNORECASE) + + # Remove commas and other formatting + price_str = price_str.replace(',', '').strip() + + # Extract number + match = re.search(r'[\d.]+', price_str) + if match: + try: + value = float(match.group()) + return value if value >= 0 else None + except ValueError: + return None + + return None + + def normalize_rating(self, rating: any) -> Optional[float]: + """ + Normalize rating to 0-5 scale. + + Handles string parsing and rescaling 0-10 ratings to 0-5. + + Args: + rating (any): Rating value (string, float, int, or None). + + Returns: + Optional[float]: Normalized rating (0-5) or None. + """ + if rating is None: + return None + + if isinstance(rating, (int, float)): + value = float(rating) + # If rating is > 5, assume it's out of 10 and scale down + if value > 5: + value = value / 2.0 + return value if 0 <= value <= 5 else None + + if isinstance(rating, str): + rating_str = rating.strip() + if not rating_str: + return None + + # Extract number + match = re.search(r'[\d.]+', rating_str) + if match: + try: + value = float(match.group()) + # If rating is > 5, assume it's out of 10 and scale down + if value > 5: + value = value / 2.0 + return value if 0 <= value <= 5 else None + except ValueError: + return None + + return None + + def normalize_review_count(self, count: any) -> Optional[int]: + """ + Normalize review count. + + Removes commas and parses integers. + + Args: + count (any): Review count value (string, int, or None). + + Returns: + Optional[int]: Normalized review count as int or None. + """ + if count is None: + return None + + if isinstance(count, int): + return count if count >= 0 else None + + if isinstance(count, str): + count_str = count.strip() + if not count_str: + return None + + # Remove commas and extract number + count_str = count_str.replace(',', '').strip() + match = re.search(r'\d+', count_str) + if match: + try: + value = int(match.group()) + return value if value >= 0 else None + except ValueError: + return None + + return None + + def clean_text(self, text: Optional[str], max_length: int = None) -> Optional[str]: + """ + Clean and normalize text. + + Removes HTML tags, normalizes whitespace, and truncates if necessary. + + Args: + text (Optional[str]): Text to clean. + max_length (int, optional): Maximum length (truncate if longer). + + Returns: + Optional[str]: Cleaned text or None if empty. + """ + if not text: + return None + + # Remove HTML tags + text = re.sub(r'<[^>]+>', '', text) + + # Normalize whitespace + text = ' '.join(text.split()) + + # Handle special characters + text = text.strip() + + # Truncate if needed + if max_length and len(text) > max_length: + text = text[:max_length-3] + '...' + + return text if text else None + + def create_embedding_text( + self, + name: str, + description: Optional[str] = None, + category: Optional[str] = None, + brand: Optional[str] = None + ) -> str: + """ + Create embedding text from product fields + + Args: + name: Product name + description: Product description + category: Product category + brand: Product brand + + Returns: + Concatenated text for embedding + """ + # Clean and prepare fields + name = self.clean_text(name) or "" + description = self.clean_text(description) or "" + category = self.clean_text(category) or "" + brand = self.clean_text(brand) or "" + + # Build embedding text using template + embedding_text = self.embedding_field_template.format( + name=name, + description=description, + category=category, + brand=brand + ) + + # Clean up the result + embedding_text = ' '.join(embedding_text.split()) + + # Ensure we have at least the name + if not embedding_text: + embedding_text = name + + return embedding_text + + def validate_product(self, product: Dict) -> Tuple[bool, List[str]]: + """ + Validate product data. + + Checks required fields and ensures numeric fields are within valid ranges. + + Args: + product (Dict): Product dictionary. + + Returns: + Tuple[bool, List[str]]: processing status (True/False) and list of error messages. + """ + errors = [] + + # Check required fields + if not product.get('name'): + errors.append("Product name is required") + + # Validate price + price = product.get('price') + if price is not None: + normalized = self.normalize_price(price) + if normalized is None: + errors.append(f"Invalid price format: {price}") + elif normalized < 0: + errors.append(f"Price cannot be negative: {normalized}") + + # Validate rating + rating = product.get('rating') + if rating is not None: + normalized = self.normalize_rating(rating) + if normalized is None: + errors.append(f"Invalid rating format: {rating}") + elif not (0 <= normalized <= 5): + errors.append(f"Rating must be between 0 and 5: {normalized}") + + # Validate review count + review_count = product.get('review_count') + if review_count is not None: + normalized = self.normalize_review_count(review_count) + if normalized is None: + errors.append(f"Invalid review count format: {review_count}") + elif normalized < 0: + errors.append(f"Review count cannot be negative: {normalized}") + + return len(errors) == 0, errors + + def process_product(self, product: Dict, generate_id: bool = True) -> Dict: + """ + Process and normalize a single product. + + Applies cleaning, normalization, and generates embedding text. + + Args: + product (Dict): Raw product dictionary. + generate_id (bool): Whether to generate UUID if ID is missing. + + Returns: + Dict: Normalized product dictionary ready for ingestion. + """ + # Generate ID if missing + if not product.get('id') and generate_id: + product['id'] = f"prod_{uuid.uuid4().hex[:12]}" + + # Normalize fields + processed = { + 'id': product.get('id'), + 'name': self.clean_text(product.get('name')), + 'description': self.clean_text(product.get('description')), + 'category': self.clean_text(product.get('category')), + 'price': self.normalize_price(product.get('price')), + 'rating': self.normalize_rating(product.get('rating')), + 'review_count': self.normalize_review_count(product.get('review_count')), + 'image_url': self.clean_text(product.get('image_url')), + 'brand': self.clean_text(product.get('brand')) + } + + # Create embedding text + processed['embedding_text'] = self.create_embedding_text( + name=processed['name'] or "", + description=processed['description'], + category=processed['category'], + brand=processed['brand'] + ) + + return processed + + def process_batch( + self, + products: List[Dict], + batch_size: int = 100, + skip_invalid: bool = True + ) -> Tuple[List[Dict], List[Dict]]: + """ + Process products in batches. + + Args: + products (List[Dict]): List of raw product dictionaries. + batch_size (int): Number of products per batch (unused in logic but kept for interface compatibility). + skip_invalid (bool): Whether to skip invalid products or raise error. + + Returns: + Tuple[List[Dict], List[Dict]]: (processed_products, invalid_products_with_errors). + """ + processed = [] + invalid = [] + + for i, product in enumerate(products): + try: + # Process product + processed_product = self.process_product(product) + + # Validate + is_valid, errors = self.validate_product(processed_product) + + if is_valid: + processed.append(processed_product) + else: + if skip_invalid: + logger.warning(f"Product {i+1} invalid: {', '.join(errors)}") + invalid.append({ + 'product': product, + 'errors': errors + }) + else: + raise ValueError(f"Product {i+1} invalid: {', '.join(errors)}") + + except Exception as e: + logger.error(f"Error processing product {i+1}: {e}") + if skip_invalid: + invalid.append({ + 'product': product, + 'errors': [str(e)] + }) + else: + raise + + logger.info(f"Processed {len(processed)} products, {len(invalid)} invalid") + return processed, invalid + + def detect_duplicates( + self, + products: List[Dict], + similarity_threshold: float = 0.9 + ) -> List[Tuple[int, int]]: + """ + Detect duplicate products by name similarity + + Args: + products: List of processed products + similarity_threshold: Similarity threshold (0-1) + + Returns: + List of (index1, index2) tuples for duplicate pairs + """ + # Simple duplicate detection based on exact name match + # In production, could use more sophisticated similarity + name_to_indices = {} + duplicates = [] + + for idx, product in enumerate(products): + name = product.get('name', '').lower().strip() + if name: + if name in name_to_indices: + # Found duplicate + for prev_idx in name_to_indices[name]: + duplicates.append((prev_idx, idx)) + name_to_indices[name].append(idx) + else: + name_to_indices[name] = [idx] + + return duplicates + diff --git a/sample_solutions/HybridSearch/api/llm/Dockerfile b/sample_solutions/HybridSearch/api/llm/Dockerfile new file mode 100644 index 00000000..90575219 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/Dockerfile @@ -0,0 +1,33 @@ +# LLM Service Dockerfile +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code and create non-root user +COPY . . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app +USER appuser + +# Expose port +EXPOSE 8003 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8003/health || exit 1 + +# Run the application +CMD ["python", "main.py"] + diff --git a/sample_solutions/HybridSearch/api/llm/api_client.py b/sample_solutions/HybridSearch/api/llm/api_client.py new file mode 100644 index 00000000..72a218c6 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/api_client.py @@ -0,0 +1,100 @@ +""" +API Client for GenAI Gateway authentication and enterprise API calls +""" + +import httpx +import logging +import re +from openai import OpenAI +from config import settings + +logger = logging.getLogger(__name__) + + +def clean_url(url: str) -> str: + """ + Remove invisible characters and whitespace from URL. + + Args: + url (str): The URL string to clean. + + Returns: + str: The cleaned URL string. + """ + if not url: + return url + # Remove non-printable characters, whitespace, and specific zero-width chars + return re.sub(r'[\x00-\x1f\x7f-\x9f\s\u200b\u2060\ufeff]+', '', url) + + +class APIClient: + """ + Client for handling GenAI Gateway authentication and API calls. + + This client manages API calls to GenAI Gateway or APISIX Gateway endpoints, + including LLM inference generation. + """ + + def __init__(self): + # Use per-model endpoint if set (APISIX), otherwise fall back to GenAI Gateway URL + self.use_apisix = bool(settings.llm_api_endpoint) + base_url = settings.llm_api_endpoint or settings.genai_gateway_url + self.base_url = clean_url(base_url).rstrip('/') if base_url else None + self.token = settings.genai_api_key + # LLM is always vLLM (even on Gaudi), so only drop /v1 for GenAI Gateway + Gaudi. + # When APISIX is in use (LLM_API_ENDPOINT set), always keep /v1. + self.use_tei = settings.inference_backend.lower() == "tei" and not self.use_apisix + self.http_client = httpx.Client(verify=settings.verify_ssl, timeout=120.0) if self.token else None + + if not self.token or not self.base_url: + raise ValueError("GenAI Gateway configuration missing. Check GENAI_GATEWAY_URL and GENAI_API_KEY.") + + logger.info(f"Using gateway at {self.base_url} (backend: {settings.inference_backend})") + + def get_inference_client(self, endpoint: str = None): + """ + Get OpenAI-style client for inference/completions. + + Args: + endpoint (str, optional): Specific model endpoint path (ignored for GenAI Gateway). + + Returns: + OpenAI: An instantiated OpenAI client configured for the GenAI Gateway. + """ + # TEI (Gaudi) serves at /chat/completions; vLLM (Xeon) serves at /v1/chat/completions + client_base_url = self.base_url if self.use_tei else f"{self.base_url}/v1" + logger.info(f"Creating OpenAI client with base_url: {client_base_url}") + + http_client = httpx.Client(verify=settings.verify_ssl, timeout=120.0) + + return OpenAI( + api_key=self.token, + base_url=client_base_url, + http_client=http_client + ) + + def is_authenticated(self) -> bool: + """ + Check if client is authenticated. + + Returns: + bool: True if authenticated, False otherwise. + """ + return bool(self.token and self.http_client) + + +# Global instance +_api_client = None + + +def get_api_client(): + """ + Get or create global API client instance. + + Returns: + APIClient: The global singleton instance of APIClient. + """ + global _api_client + if _api_client is None: + _api_client = APIClient() + return _api_client diff --git a/sample_solutions/HybridSearch/api/llm/clean_monologue.py b/sample_solutions/HybridSearch/api/llm/clean_monologue.py new file mode 100644 index 00000000..81bab3b0 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/clean_monologue.py @@ -0,0 +1,68 @@ +""" +Helper function to clean internal monologue from LLM responses +""" +import re + + +def clean_internal_monologue(text: str) -> str: + """ + Remove internal thinking/monologue from LLM response. + + Handles removal of: + 1. ... tags (typical of Qwen/Reasoning models) + 2. Internal monologue patterns at the start of the response + + Args: + text (str): Raw LLM response text. + + Returns: + str: Cleaned text with internal monologue removed. + """ + if not text: + return text + + # Step 1: Remove ... blocks (Qwen's internal thinking) + text = re.sub(r'.*?', '', text, flags=re.DOTALL | re.IGNORECASE) + + # Step 2: Remove any remaining thinking patterns at the start + paragraphs = text.split('\n\n') + + if len(paragraphs) <= 1: + return text.strip() + + # Patterns that indicate internal thinking (case-insensitive) + thinking_indicators = [ + r'\bokay,?\s+let\'?s\b', + r'\bfirst,?\s+i\b', + r'\bi\s+need\s+to\b', + r'\bi\s+should\b', + r'\bstarting\s+with\b', + r'\bputting\s+(this|it)\s+together\b', + r'\bthe\s+user\s+(wants|is\s+asking)\b', + r'\blooking\s+at\b', + r'\bgoing\s+through\b', + ] + + # Find first paragraph that's NOT thinking + cleaned_paragraphs = [] + found_content = False + + for para in paragraphs: + para_stripped = para.strip() + if not para_stripped: + continue + + para_lower = para_stripped.lower() + + # Check if this is thinking + has_thinking = any(re.search(pattern, para_lower, re.IGNORECASE) for pattern in thinking_indicators) + + if not has_thinking or found_content: + cleaned_paragraphs.append(para) + found_content = True + + # If we filtered everything, return original (safety) + if not cleaned_paragraphs: + return text.strip() + + return '\n\n'.join(cleaned_paragraphs).strip() \ No newline at end of file diff --git a/sample_solutions/HybridSearch/api/llm/config.py b/sample_solutions/HybridSearch/api/llm/config.py new file mode 100644 index 00000000..828efbb3 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/config.py @@ -0,0 +1,103 @@ +""" +LLM Service Configuration +Manages environment variables and service settings +Supports GenAI Gateway and APISIX Gateway +""" + +from pydantic_settings import BaseSettings +from typing import Optional +from pathlib import Path + + +class Settings(BaseSettings): + """ + Service configuration with environment variable loading. + + This class defines the configuration settings for the LLM Service, + including deployment phase and GenAI Gateway/APISIX Gateway settings. + """ + + # Deployment Phase + deployment_phase: str = "development" + + # GenAI Gateway Configuration + # Supports multiple deployment patterns: + # - GenAI Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + # - APISIX Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + genai_gateway_url: Optional[str] = None + genai_api_key: Optional[str] = None + + # Per-model endpoint URL (required for APISIX, optional for GenAI Gateway) + llm_api_endpoint: Optional[str] = None + + # Inference backend type: "vllm" (Xeon, default) or "tei" (Gaudi) + # TEI does not use the /v1 path prefix; vLLM does + inference_backend: str = "vllm" + + # Model Configuration (for Enterprise) + llm_model_endpoint: str = "Qwen/Qwen3-4B-Instruct-2507" + llm_model_name: str = "Qwen/Qwen3-4B-Instruct-2507" + + # Dual Model Configuration + inference_model_endpoint_simple: str = "Qwen3-4B-Instruct-2507-vllmcpu" + inference_model_name_simple: str = "Qwen/Qwen3-4B-Instruct-2507" + + inference_model_endpoint_complex: str = "Qwen3-4B-Instruct-2507-vllmcpu" + inference_model_name_complex: str = "Qwen/Qwen3-4B-Instruct-2507" + + # Service Configuration + llm_port: int = 8003 + llm_host: str = "0.0.0.0" # nosec B104 - Binding to all interfaces is intentional for Docker container + + # LLM Parameters + max_tokens_simple: int = 512 + max_tokens_complex: int = 512 + temperature_simple: float = 0.1 + temperature_complex: float = 0.6 + + # SSL Verification Settings + verify_ssl: bool = True + + # Product Catalog Settings + system_mode: str = "document" # "document" or "product" + + # Logging + log_level: str = "INFO" + + class Config: + """Pydantic configuration.""" + # Look for .env file in the hybrid-search root directory + env_file = Path(__file__).parent.parent.parent / ".env" + case_sensitive = False + extra = "ignore" # Ignore extra fields in .env file + + def is_enterprise_configured(self) -> bool: + """ + Check if GenAI Gateway is configured. + + Returns: + bool: True if genai_gateway_url and genai_api_key are present. + """ + return bool(self.genai_gateway_url and self.genai_api_key) + + def validate_config(self): + """ + Validate that GenAI Gateway is configured. + + This service requires GenAI Gateway or APISIX Gateway authentication. + + Raises: + ValueError: If required configuration is missing. + """ + if not self.is_enterprise_configured(): + raise ValueError( + "GenAI Gateway configuration missing. " + "Must provide GENAI_GATEWAY_URL and GENAI_API_KEY in .env file." + ) + + +# Global settings instance +settings = Settings() + +# Validate configuration on import +settings.validate_config() diff --git a/sample_solutions/HybridSearch/api/llm/main.py b/sample_solutions/HybridSearch/api/llm/main.py new file mode 100644 index 00000000..6516eeca --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/main.py @@ -0,0 +1,936 @@ +""" +LLM Service - OpenAI API Wrapper for Question Answering +Handles dual-model routing for simple and complex queries +""" + +import logging +import time +import re +from pathlib import Path +from typing import List, Optional, Dict, Any +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from openai import OpenAI, OpenAIError, RateLimitError, APIConnectionError, APITimeoutError +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log +) +from config import settings +from services.response_formatter import ResponseFormatter +from prompts.product_prompts import ProductPrompts +from clean_monologue import clean_internal_monologue + +# Configure logging +logging.basicConfig( + level=settings.log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title="LLM Service", + description="OpenAI-powered question answering service with dual-model routing", + version="1.0.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize GenAI Gateway API client +try: + from api_client import get_api_client + + api_client = get_api_client() + + if not api_client.is_authenticated(): + raise RuntimeError("GenAI Gateway authentication failed - cannot start service without API access") + + client = api_client.get_inference_client() + logger.info("✓ GenAI Gateway API client initialized successfully") + logger.info(f" Simple model: {settings.inference_model_name_simple}") + logger.info(f" Complex model: {settings.inference_model_name_complex}") + logger.info(f" Authentication: GenAI Gateway API Key") + logger.info(f" Base URL: {settings.genai_gateway_url}") + +except Exception as e: + logger.error(f"Failed to initialize GenAI Gateway API client: {e}") + logger.error("Service requires GenAI Gateway authentication and endpoints") + raise RuntimeError(f"GenAI Gateway API initialization failed: {e}") from e + + +# Load prompt templates +PROMPTS_DIR = Path(__file__).parent / "prompts" +SIMPLE_QA_PROMPT = (PROMPTS_DIR / "simple_qa.txt").read_text() +COMPLEX_QA_PROMPT = (PROMPTS_DIR / "complex_qa.txt").read_text() + +# Initialize product response formatter +response_formatter = ResponseFormatter() +product_prompts = ProductPrompts() + + +# Request/Response Models +class RetrievalChunk(BaseModel): + """ + Retrieved document chunk model. + + Attributes: + chunk_id: Unique identifier for the chunk. + document_id: ID of the parent document. + text: Text content of the chunk. + page_number: Page number (optional). + score: Relevance score (similarity/ranking). + metadata: Additional metadata dictionary. + """ + chunk_id: str + document_id: str + text: str + page_number: Optional[int] = None + score: float + metadata: Dict[str, Any] = {} + + +class Citation(BaseModel): + """ + Citation model for generated answers. + + Attributes: + document_id: ID of the cited document. + page_number: Page number in the document (optional). + chunk_id: ID of the specific chunk used. + confidence_score: Relevance score of the chunk. + relevant_text_snippet: Snippet of text justifying the citation. + """ + document_id: str + page_number: Optional[int] = None + chunk_id: str + confidence_score: float + relevant_text_snippet: str + + +class LLMRequest(BaseModel): + """Request model for LLM generation""" + query: str = Field(..., description="User query") + context_chunks: List[RetrievalChunk] = Field( + ..., + description="Retrieved context chunks" + ) + model_type: str = Field( + "auto", + description="Model type: 'simple', 'complex', or 'auto'" + ) + max_tokens: Optional[int] = Field( + None, + description="Maximum tokens to generate (overrides defaults)" + ) + temperature: Optional[float] = Field( + None, + description="Temperature for generation (overrides defaults)" + ) + include_citations: bool = Field( + True, + description="Whether to extract citations from response" + ) + + class Config: + json_schema_extra = { + "example": { + "query": "What are the main findings?", + "context_chunks": [ + { + "chunk_id": "chunk_1", + "document_id": "doc_123", + "text": "The study found significant improvements...", + "page_number": 5, + "score": 0.92, + "metadata": {} + } + ], + "model_type": "auto", + "include_citations": True + } + } + + +class LLMResponse(BaseModel): + """ + Response model for LLM generation. + + Attributes: + answer: Generated answer text. + citations: List of citations supporting the answer. + model_used: Name of the model used for generation. + query_type: Detected complexity of the query ('simple' or 'complex'). + generation_time_ms: Time taken for generation in milliseconds. + token_count: Total tokens used (if available). + """ + answer: str + citations: List[Citation] + model_used: str + query_type: str + generation_time_ms: float + token_count: Optional[int] = None + + +class HealthResponse(BaseModel): + """Health check response""" + status: str + service: str + deployment_phase: str + models: Dict[str, str] + + +class ModelInfoResponse(BaseModel): + """Model information response""" + simple_model: str + complex_model: str + max_tokens_simple: int + max_tokens_complex: int + temperature_simple: float + temperature_complex: float + + +# Helper Functions +def format_context(chunks: List[RetrievalChunk]) -> str: + """ + Format retrieved chunks into a context string for the LLM. + + Args: + chunks (List[RetrievalChunk]): List of retrieved document chunks. + + Returns: + str: Formatted context string with document IDs and page numbers. + """ + context_parts = [] + for idx, chunk in enumerate(chunks, 1): + page_info = f" [Page {chunk.page_number}]" if chunk.page_number else "" + context_parts.append( + f"[{idx}] Document: {chunk.document_id}{page_info}\n{chunk.text}\n" + ) + return "\n".join(context_parts) + + +def extract_citations( + answer: str, + context_chunks: List[RetrievalChunk] +) -> List[Citation]: + """ + Extract citations from the answer text. + + Parses citations in formats like [Page X], [Page X-Y], [Doc ID, Page X] + and maps them back to the original context chunks. + + Args: + answer (str): Generated answer text. + context_chunks (List[RetrievalChunk]): Original context chunks used for generation. + + Returns: + List[Citation]: List of extracted citation objects. + """ + citations = [] + + # Pattern to match citations: [Page X], [Page X-Y], [Doc ID, Page X] + citation_patterns = [ + r'\[Page (\d+)\]', + r'\[Page (\d+)-(\d+)\]', + r'\[([^,]+), Page (\d+)\]' + ] + + # Track which chunks were cited + cited_chunks = set() + + for i, pattern in enumerate(citation_patterns): + matches = re.finditer(pattern, answer) + for match in matches: + try: + # Extract page number based on pattern type + if i == 0: # [Page X] + page_num = int(match.group(1)) + elif i == 1: # [Page X-Y] + page_num = int(match.group(1)) # Use start page + elif i == 2: # [Doc ID, Page X] + page_num = int(match.group(2)) + else: + continue + + # Find matching chunks + for chunk in context_chunks: + if chunk.page_number == page_num and chunk.chunk_id not in cited_chunks: + # Extract snippet around the citation + snippet = chunk.text[:200] + "..." if len(chunk.text) > 200 else chunk.text + + citations.append(Citation( + document_id=chunk.document_id, + page_number=chunk.page_number, + chunk_id=chunk.chunk_id, + confidence_score=chunk.score, + relevant_text_snippet=snippet + )) + cited_chunks.add(chunk.chunk_id) + break + except (ValueError, IndexError, AttributeError): + continue + + # If no explicit citations found, use top-3 chunks as implicit citations + if not citations and context_chunks: + for chunk in context_chunks[:3]: + snippet = chunk.text[:200] + "..." if len(chunk.text) > 200 else chunk.text + citations.append(Citation( + document_id=chunk.document_id, + page_number=chunk.page_number, + chunk_id=chunk.chunk_id, + confidence_score=chunk.score, + relevant_text_snippet=snippet + )) + + return citations + + +def detect_query_complexity(query: str) -> str: + """ + Detect if query is simple or complex based on keywords and heuristics. + + Simple queries (fact retrieval) use smaller models. + Complex queries (analysis, comparison) use reasoning models. + + Args: + query (str): User query string. + + Returns: + str: 'simple' or 'complex'. + """ + query_lower = query.lower() + + # Complex indicators + complex_indicators = [ + "compare", "analyze", "explain why", "relationship between", + "impact of", "evaluate", "synthesize", "how does", "affect", + "differences between", "similarities", "trend", "pattern", + "correlation", "cause", "effect" + ] + + # Simple indicators + simple_indicators = [ + "what is", "who is", "when did", "where is", + "define", "list", "name", "how many" + ] + + # Check for complex indicators + if any(indicator in query_lower for indicator in complex_indicators): + return "complex" + + # Check for simple indicators + if any(indicator in query_lower for indicator in simple_indicators): + return "simple" + + # Default heuristics + word_count = len(query.split()) + question_count = query.count("?") + + if word_count > 15 or question_count > 1: + return "complex" + + return "simple" + + +def clean_internal_monologue(text: str) -> str: + """ + Remove internal thinking/monologue from LLM response. + Handles: + 1. ... tags (Qwen models) + 2. Internal monologue patterns at the start + + Args: + text: Raw LLM response + + Returns: + Cleaned text with internal monologue removed + """ + if not text: + return text + + # STEP 1: Remove ... blocks (Qwen's internal thinking) + text = re.sub(r'.*?', '', text, flags=re.DOTALL | re.IGNORECASE) + text = text.strip() + + if not text: + return text + + # STEP 2: Remove paragraph-level thinking patterns + paragraphs = text.split('\n\n') + + # Patterns that indicate internal thinking (case-insensitive) + thinking_patterns = [ + r'^okay,?\s+let\'?s', + r'^first,?\s+i\s', + r'^i\s+need\s+to', + r'^i\s+should', + r'^i\s+will', + r'^i\'ll', + r'^starting\s+with', + r'^putting\s+this\s+together', + r'^the\s+user\s+(wants|is\s+asking)', + r'^looking\s+at', + r'^going\s+through', + r'^analyzing', + r'^from\s+what\s+i\s+can\s+see', + ] + + cleaned_paragraphs = [] + skip_mode = False + + for para in paragraphs: + para_stripped = para.strip() + if not para_stripped: + continue + + para_lower = para_stripped.lower() + + # Check if this paragraph starts with thinking patterns + is_thinking = any(re.match(pattern, para_lower) for pattern in thinking_patterns) + + # Check for first-person pronouns at the start + has_first_person_start = re.match(r'^(i\s|my\s|we\s|our\s)', para_lower) + + # If we find thinking patterns, skip until we find actual content + if is_thinking or (has_first_person_start and len(para_stripped.split()) < 50): + skip_mode = True + continue + + # Look for section headers or structured content (likely the actual answer) + if re.match(r'^#+\s+', para_stripped) or re.match(r'^\d+\.', para_stripped) or re.match(r'^[A-Z][^.!?]*:', para_stripped): + skip_mode = False + + # If we're past the thinking phase, keep the content + if not skip_mode: + cleaned_paragraphs.append(para) + + # If we filtered everything out, return the original (safety fallback) + if not cleaned_paragraphs: + # Try to find the first paragraph that looks like actual content + for para in paragraphs: + para_stripped = para.strip() + if len(para_stripped) > 100 and not any(re.match(pattern, para_stripped.lower()) for pattern in thinking_patterns): + cleaned_paragraphs.append(para) + break + + # If still nothing, return original + if not cleaned_paragraphs: + return text + + return '\n\n'.join(cleaned_paragraphs).strip() + + +# Retry Configuration for OpenAI API +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True +) +def _call_chat_completion( + client_instance: OpenAI, + model: str, + prompt: str, + max_tokens: int, + temperature: float +): + """ + Call chat completion API with retry logic + + Args: + client_instance: OpenAI client instance (or compatible) + model: Model name + prompt: User prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + + Returns: + Chat completion response + + Raises: + OpenAIError: If all retries fail + """ + try: + return client_instance.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "You are a helpful document analysis assistant. You must output ONLY the final answer. Do not include any internal monologue, thinking process, or self-correction. Start your response directly with the answer."}, + {"role": "user", "content": prompt} + ], + max_tokens=max_tokens, + temperature=temperature + ) + except (RateLimitError, APIConnectionError, APITimeoutError) as e: + logger.warning(f"API error (will retry): {type(e).__name__}: {e}") + raise + except OpenAIError as e: + # Don't retry on other errors (invalid request, etc.) + logger.error(f"API error (non-retryable): {e}") + raise + + +# API Endpoints +@app.post( + "/api/v1/llm/generate", + response_model=LLMResponse, + status_code=status.HTTP_200_OK, + summary="Generate answer for query", + description="Generate answer using appropriate LLM based on query complexity" +) +async def generate_answer(request: LLMRequest): + """ + Generate answer for the query using retrieved context. + + Orchestrates the generation process: + 1. Determines query complexity (if set to auto) + 2. Selects appropriate model (simple vs complex) + 3. Formats context chunks + 4. Calls LLM (OpenAI or Enterprise) + 5. Cleans response and extracts citations + + Args: + request (LLMRequest): Request object containing query, context chunks, and parameters. + + Returns: + LLMResponse: Generated answer with metadata and citations. + + Raises: + HTTPException: If LLM API call fails. + """ + try: + start_time = time.time() + + # Determine query type + if request.model_type == "auto": + query_type = detect_query_complexity(request.query) + else: + query_type = request.model_type + + # Select model and parameters + current_client = client # Default to global client + + if settings.is_enterprise_configured(): + # Enterprise API with dual model support + if query_type == "simple": + model = settings.inference_model_name_simple + endpoint = settings.inference_model_endpoint_simple + max_tokens = request.max_tokens or settings.max_tokens_simple + temperature = request.temperature or settings.temperature_simple + prompt_template = SIMPLE_QA_PROMPT + else: + model = settings.inference_model_name_complex + endpoint = settings.inference_model_endpoint_complex + max_tokens = request.max_tokens or settings.max_tokens_complex + temperature = request.temperature or settings.temperature_complex + prompt_template = COMPLEX_QA_PROMPT + + # Get specific client for the endpoint + from api_client import get_api_client + api_client_inst = get_api_client() + current_client = api_client_inst.get_inference_client(endpoint=endpoint) + + else: + # Fallback (should not be reached if config validation works) + logger.warning("Enterprise config missing, using simple model default") + model = settings.inference_model_name_simple + max_tokens = request.max_tokens or settings.max_tokens_simple + temperature = request.temperature or settings.temperature_simple + prompt_template = SIMPLE_QA_PROMPT + + logger.info( + f"Generating answer using {model} " + f"(query_type={query_type}, chunks={len(request.context_chunks)})" + ) + + # Format context + context = format_context(request.context_chunks) + + # Build prompt + prompt = prompt_template.format(context=context, query=request.query) + + # Call API with retry logic + response = _call_chat_completion( + client_instance=current_client, + model=model, + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature + ) + + # Extract answer + answer = response.choices[0].message.content + token_count = response.usage.total_tokens if response.usage else None + + # Log raw LLM output for debugging + logger.info(f"Raw LLM output (first 500 chars): {answer[:500]}...") + logger.info(f"Total LLM output length: {len(answer)} characters") + + + cleaned_answer = clean_internal_monologue(answer) + logger.info(f"Cleaned output (first 500 chars): {cleaned_answer[:500]}...") + logger.info(f"Cleaning removed {len(answer) - len(cleaned_answer)} characters") + answer = cleaned_answer + + # Extract citations + citations = [] + if request.include_citations: + citations = extract_citations(answer, request.context_chunks) + + # Calculate processing time + processing_time = (time.time() - start_time) * 1000 + + logger.info( + f"Answer generated in {processing_time:.2f}ms " + f"(tokens={token_count}, citations={len(citations)})" + ) + + return LLMResponse( + answer=answer, + citations=citations, + model_used=model, + query_type=query_type, + generation_time_ms=round(processing_time, 2), + token_count=token_count + ) + + except OpenAIError as e: + logger.error(f"OpenAI API error: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=f"OpenAI API error: {str(e)}" + ) + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {str(e)}" + ) + + +@app.post( + "/api/v1/llm/generate/simple", + response_model=LLMResponse, + status_code=status.HTTP_200_OK, + summary="Generate answer using simple model", + description="Force use of simple model (gpt-4o-mini) for factual questions" +) +async def generate_simple(request: LLMRequest): + """ + Generate answer using simple model. + + Forces the use of the 'simple' model configuration (optimized for speed and fact retrieval), + bypassing complexity detection. + + Args: + request (LLMRequest): Request object. + + Returns: + LLMResponse: Generated answer. + """ + request.model_type = "simple" + return await generate_answer(request) + + +@app.post( + "/api/v1/llm/generate/complex", + response_model=LLMResponse, + status_code=status.HTTP_200_OK, + summary="Generate answer using complex model", + description="Force use of complex model (gpt-4-turbo) for analytical questions" +) +async def generate_complex(request: LLMRequest): + """ + Generate answer using complex model. + + Forces the use of the 'complex' model configuration (reasoning/analysis focused), + bypassing complexity detection. + + Args: + request (LLMRequest): Request object. + + Returns: + LLMResponse: Generated answer. + """ + request.model_type = "complex" + return await generate_answer(request) + + +@app.get( + "/health", + response_model=HealthResponse, + status_code=status.HTTP_200_OK, + summary="Health check" +) +async def health_check(): + """ + Health check endpoint. + + Returns: + HealthResponse: Status of the service, current deployment phase, + and configured models. + """ + if settings.is_enterprise_configured(): + models_info = { + "provider": "GenAI Gateway", + "simple_model": settings.inference_model_name_simple, + "complex_model": settings.inference_model_name_complex + } + else: + models_info = { + "provider": "GenAI Gateway (Not Configured)", + "simple": "N/A", + "complex": "N/A" + } + + return HealthResponse( + status="healthy", + service="llm", + deployment_phase=settings.deployment_phase, + models=models_info + ) + + +@app.get( + "/api/v1/llm/models/info", + response_model=ModelInfoResponse, + status_code=status.HTTP_200_OK, + summary="Get model information" +) +async def get_model_info(): + """ + Get LLM model information. + + Returns: + ModelInfoResponse: Configuration details of currently active models + (simple/complex names, token limits, temperature). + """ + return ModelInfoResponse( + simple_model=settings.inference_model_name_simple, + complex_model=settings.inference_model_name_complex, + max_tokens_simple=settings.max_tokens_simple, + max_tokens_complex=settings.max_tokens_complex, + temperature_simple=settings.temperature_simple, + temperature_complex=settings.temperature_complex + ) + + +# Product Endpoints +class ProductRecommendationRequest(BaseModel): + """ + Request model for product recommendation. + + Attributes: + query: User's original query. + products: List of retrieved product dictionaries to analyze. + intent: Detected intent type. + filters: Applied filters (optional). + mode: Response mode ('quick' or 'explained'). + """ + query: str = Field(..., description="User query") + products: List[Dict[str, Any]] = Field(..., description="List of products") + intent: str = Field(..., description="Query intent type") + filters: Optional[Dict[str, Any]] = Field(None, description="Applied filters") + mode: str = Field("explained", description="Response mode: 'quick' or 'explained'") + + +class ProductRecommendationResponse(BaseModel): + """ + Response model for product recommendation. + + Attributes: + recommendation: Generated recommendation text/HTML. + mode: Mode used for generation ('quick' or 'explained'). + products_count: Number of products analyzed. + """ + recommendation: str + mode: str + products_count: int + + +@app.post( + "/api/v1/llm/generate/product-recommendation", + response_model=ProductRecommendationResponse, + status_code=status.HTTP_200_OK, + summary="Generate product recommendation", + description="Generate personalized product recommendations" +) +async def generate_product_recommendation(request: ProductRecommendationRequest): + """ + Generate product recommendation. + + Uses either a template-based 'quick' mode for standard listings or an + LLM-based 'explained' mode for in-depth advice, depending on intent + and result count. + + Args: + request (ProductRecommendationRequest): Request data. + + Returns: + ProductRecommendationResponse: Generated recommendation content. + + Raises: + HTTPException: If generation fails. + """ + try: + # Use formatter to determine mode and generate response + use_quick = response_formatter.should_use_quick_mode( + intent=request.intent, + product_count=len(request.products), + has_filters=bool(request.filters) + ) + + mode = "quick" if use_quick or request.mode == "quick" else "explained" + + if mode == "quick": + # Template-based response + recommendation = response_formatter.format_response( + query=request.query, + products=request.products, + intent=request.intent, + filters=request.filters, + mode="quick" + ) + else: + # LLM-generated response + prompt = response_formatter.format_response( + query=request.query, + products=request.products, + intent=request.intent, + filters=request.filters, + mode="explained" + ) + + # Determine model and client + if settings.is_enterprise_configured(): + model = settings.inference_model_name_simple + endpoint = settings.inference_model_endpoint_simple + from api_client import get_api_client + api_client_inst = get_api_client() + current_client = api_client_inst.get_inference_client(endpoint=endpoint) + else: + model = settings.inference_model_name_simple # Default fallback + current_client = client + + # Call LLM + response = current_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "You are a helpful shopping assistant. You must output ONLY the final recommendation. Do not include any internal monologue or thinking process."}, + {"role": "user", "content": prompt} + ], + max_tokens=settings.max_tokens_simple, + temperature=settings.temperature_simple + ) + + recommendation = response.choices[0].message.content + + return ProductRecommendationResponse( + recommendation=recommendation, + mode=mode, + products_count=len(request.products) + ) + + except Exception as e: + logger.error(f"Error generating product recommendation: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate recommendation: {str(e)}" + ) + + +@app.post( + "/api/v1/llm/generate/filtered-results", + response_model=ProductRecommendationResponse, + status_code=status.HTTP_200_OK, + summary="Generate filtered results explanation", + description="Generate explanation for filtered search results" +) +async def generate_filtered_results(request: ProductRecommendationRequest): + """ + Generate explanation for filtered search results. + + Wrapper around recommendation generation with forced 'filtered_search' intent. + + Args: + request (ProductRecommendationRequest): Request data. + + Returns: + ProductRecommendationResponse: Generated explanation. + """ + request.intent = "filtered_search" + return await generate_product_recommendation(request) + + +@app.post( + "/api/v1/llm/generate/comparison", + response_model=ProductRecommendationResponse, + status_code=status.HTTP_200_OK, + summary="Generate product comparison", + description="Generate comparison between products" +) +async def generate_comparison(request: ProductRecommendationRequest): + """ + Generate product comparison. + + Wrapper around recommendation generation with forced 'comparison' intent. + + Args: + request (ProductRecommendationRequest): Request data. + + Returns: + ProductRecommendationResponse: Generated comparison. + """ + request.intent = "comparison" + return await generate_product_recommendation(request) + + +@app.get("/", summary="Root endpoint") +async def root(): + """ + Root endpoint with service information. + + Returns: + dict: Basic service info including version and status. + """ + return { + "service": "LLM Service", + "version": "1.0.0", + "status": "running", + "docs": "/docs", + "health": "/health" + } + + +# Application startup +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting LLM Service on {settings.llm_host}:{settings.llm_port}") + logger.info(f"Deployment phase: {settings.deployment_phase}") + + if settings.is_enterprise_configured(): + logger.info("Provider: GenAI Gateway") + logger.info(f"Simple model: {settings.inference_model_name_simple}") + logger.info(f"Complex model: {settings.inference_model_name_complex}") + else: + logger.warning("Provider: Not configured (GenAI Gateway required)") + + uvicorn.run( + app, + host=settings.llm_host, # nosec B104 - Binding to all interfaces is intentional for Docker container + port=settings.llm_port, + log_level=settings.log_level.lower() + ) + diff --git a/sample_solutions/HybridSearch/api/llm/prompts/complex_qa.txt b/sample_solutions/HybridSearch/api/llm/prompts/complex_qa.txt new file mode 100644 index 00000000..3e19ebb1 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/prompts/complex_qa.txt @@ -0,0 +1,27 @@ +You are an expert analyst. Use the provided context to perform deep analysis and reasoning. +Synthesize information across documents to provide comprehensive insights. + +Context: +{context} + +Question: {query} + +Instructions: +1. Analyze information across multiple sources +2. Identify patterns, relationships, and connections +3. Provide well-reasoned conclusions with supporting evidence +4. Cite sources with [Doc ID, Page X] format when referencing information +5. Acknowledge limitations in the available data +6. Compare and contrast when relevant +7. Provide actionable insights when appropriate +8. IMPORTANT: Output ONLY the final analysis. Do NOT output any thinking, reasoning, or internal monologue. +9. Do NOT use first person ("I", "me", "my", "we", "our"). +10. Start your response directly with the analysis. + +Example: +Context: Report A shows 5% growth. Report B shows 3% decline. +Question: Compare the reports. +Analysis: Report A indicates positive growth of 5% [Doc 1], whereas Report B shows a decline of 3% [Doc 2]. This suggests a divergence in performance metrics between the two observed periods. + +Analysis: + diff --git a/sample_solutions/HybridSearch/api/llm/prompts/product_prompts.py b/sample_solutions/HybridSearch/api/llm/prompts/product_prompts.py new file mode 100644 index 00000000..251f9377 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/prompts/product_prompts.py @@ -0,0 +1,149 @@ +""" +Product Response Prompts +Prompt templates for different product search response types +""" + +from typing import List, Dict, Any + + +class ProductPrompts: + """Product-specific prompt templates""" + + @staticmethod + def semantic_browse_prompt(query: str, products: List[Dict]) -> str: + """ + Generate prompt for semantic browse (recommendation style) + + Args: + query: User query + products: List of product dictionaries + + Returns: + Formatted prompt string + """ + products_json = "\n".join([ + f"- {p.get('name', 'Unknown')} (${p.get('price', 'N/A')}) - {p.get('description', '')[:100]}... " + f"Rating: {p.get('rating', 'N/A')} stars" + for p in products[:10] + ]) + + return f"""You are a helpful shopping assistant. Given the user's query and search results, +provide a friendly recommendation. Explain WHY each product matches their needs. +Keep it concise - 2-3 sentences per product, focus on relevance to their query. + +User Query: {query} +Top Results: +{products_json} + +Format: Brief intro, then for each recommended product: +- Product name and price +- Why it matches their query +- Key highlight (rating, feature, value) + +Be conversational and helpful. Do not include any internal thought process or monologue. Provide ONLY the final recommendation.""" + + @staticmethod + def filtered_search_prompt(query: str, filters: Dict, products: List[Dict]) -> str: + """ + Generate prompt for filtered search (factual listing) + + Args: + query: User query + filters: Applied filters + products: List of product dictionaries + + Returns: + Formatted prompt string + """ + filter_summary = [] + if filters.get('price_max'): + filter_summary.append(f"under ${filters['price_max']}") + if filters.get('price_min'): + filter_summary.append(f"over ${filters['price_min']}") + if filters.get('rating_min'): + filter_summary.append(f"{filters['rating_min']}+ stars") + if filters.get('categories'): + filter_summary.append(f"in {', '.join(filters['categories'])}") + + filter_text = " and ".join(filter_summary) if filter_summary else "all products" + + products_json = "\n".join([ + f"- {p.get('name', 'Unknown')} - ${p.get('price', 'N/A')} - " + f"Rating: {p.get('rating', 'N/A')} stars ({p.get('review_count', 0)} reviews)" + for p in products[:10] + ]) + + return f"""Present these filtered search results clearly. Confirm the filters applied, +then list products with key details. Be factual and concise. + +Query: {query} +Filters Applied: {filter_text} +Results: +{products_json} + +Format: "Found X products {filter_text}. Here are the top matches:" +Then list with name, price, rating, and one key feature each. Do not include any internal thought process.""" + + @staticmethod + def comparison_prompt(products: List[Dict], priorities: List[str] = None) -> str: + """ + Generate prompt for product comparison + + Args: + products: List of product dictionaries to compare + priorities: User's priorities (e.g., ["price", "rating"]) + + Returns: + Formatted prompt string + """ + products_json = "\n".join([ + f"Product {i+1}: {p.get('name', 'Unknown')}\n" + f" Price: ${p.get('price', 'N/A')}\n" + f" Rating: {p.get('rating', 'N/A')} stars ({p.get('review_count', 0)} reviews)\n" + f" Description: {p.get('description', '')[:150]}...\n" + f" Category: {p.get('category', 'N/A')}" + for i, p in enumerate(products[:5]) + ]) + + priorities_text = f"\nUser's Priorities: {', '.join(priorities)}" if priorities else "" + + return f"""Compare these products objectively. Create a brief comparison highlighting +key differences in price, features, and ratings. Help the user decide. + +Products to Compare: +{products_json}{priorities_text} + +Format: Structured comparison, then a recommendation based on different use cases. +Be objective and helpful. Do not include any internal thought process.""" + + @staticmethod + def quick_results_template(query: str, filters: Dict, products: List[Dict]) -> str: + """ + Generate quick template-based response (no LLM) + + Args: + query: User query + filters: Applied filters + products: List of product dictionaries + + Returns: + Template-based response string + """ + filter_text = "" + if filters.get('price_max'): + filter_text += f" under ${filters['price_max']}" + if filters.get('rating_min'): + filter_text += f" with {filters['rating_min']}+ stars" + + response = f"Here are {len(products)} products matching '{query}'{filter_text}:\n\n" + + for i, product in enumerate(products[:10], 1): + response += f"{i}. {product.get('name', 'Unknown')}\n" + if product.get('price'): + response += f" Price: ${product['price']:.2f}\n" + if product.get('rating'): + response += f" Rating: {product['rating']:.1f} stars ({product.get('review_count', 0)} reviews)\n" + response += "\n" + + return response + diff --git a/sample_solutions/HybridSearch/api/llm/prompts/simple_qa.txt b/sample_solutions/HybridSearch/api/llm/prompts/simple_qa.txt new file mode 100644 index 00000000..696fbee7 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/prompts/simple_qa.txt @@ -0,0 +1,25 @@ +You are a precise document assistant. Answer the question using ONLY the provided context. +Be concise and factual. If the answer is not in the context, say "I don't have enough information." + +Context: +{context} + +Question: {query} + +Instructions: +1. Provide a direct, factual answer +2. Cite page numbers in [Page X] format when referencing specific information +3. Quote relevant text when appropriate +4. Be concise and to the point +5. If the information is not in the context, clearly state that +6. IMPORTANT: Output ONLY the final answer. Do NOT output any thinking, reasoning, or internal monologue. +7. Do NOT use first person ("I", "me", "my", "we", "our"). +8. Start your response directly with the answer. + +Example: +Context: The project was completed in 2023. +Question: When was the project finished? +Answer: The project was completed in 2023 [Page 1]. + +Answer: + diff --git a/sample_solutions/HybridSearch/api/llm/requirements.txt b/sample_solutions/HybridSearch/api/llm/requirements.txt new file mode 100644 index 00000000..011e894c --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/requirements.txt @@ -0,0 +1,30 @@ +# LLM Service Requirements +# API Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 + +# OpenAI API +openai>=1.35.0 +httpx>=0.25.0 # Required for OpenAI client +requests>=2.32.0 # Security updates + +# Logging +python-json-logger==2.0.7 + +# Retry Logic +tenacity==8.2.3 + +# Environment +python-dotenv==1.0.0 + +# Text processing +tiktoken==0.5.1 + +# Production Phase Dependencies (will be used later) +# torch==2.1.0 +# transformers==4.35.0 +# intel-extension-for-pytorch==2.1.0 +# vllm==0.2.1 + diff --git a/sample_solutions/HybridSearch/api/llm/services/__init__.py b/sample_solutions/HybridSearch/api/llm/services/__init__.py new file mode 100644 index 00000000..46fbfd98 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/services/__init__.py @@ -0,0 +1,2 @@ +"""LLM services module""" + diff --git a/sample_solutions/HybridSearch/api/llm/services/response_formatter.py b/sample_solutions/HybridSearch/api/llm/services/response_formatter.py new file mode 100644 index 00000000..21217317 --- /dev/null +++ b/sample_solutions/HybridSearch/api/llm/services/response_formatter.py @@ -0,0 +1,102 @@ +""" +Response Formatter +Formats product recommendations in quick or explained modes +""" + +import logging +from typing import List, Dict, Any, Optional +from prompts.product_prompts import ProductPrompts + +logger = logging.getLogger(__name__) + +# Query intent constants (matching gateway service) +class QueryIntent: + SEMANTIC_BROWSE = "semantic_browse" + FILTERED_SEARCH = "filtered_search" + HYBRID_SEARCH = "hybrid" + SPECIFIC_PRODUCT = "specific_product" + COMPARISON = "comparison" + + +class ResponseFormatter: + """ + Format product search responses. + + Handles formatting of product data into either quick summaries (template-based) + or detailed explanations (LLM-prompt-ready). + """ + + def __init__(self): + """Initialize response formatter with prompt templates.""" + self.prompts = ProductPrompts() + + def format_response( + self, + query: str, + products: List[Dict], + intent: str, + filters: Dict = None, + mode: str = "explained" + ) -> str: + """ + Format product search response. + + Generates the final response text or prompt input based on the mode. + + Args: + query (str): User's search query. + products (List[Dict]): List of product dictionaries. + intent (str): Detected query intent (e.g., 'comparison', 'filtered_search'). + filters (Dict, optional): Applied filters. + mode (str): Response mode ("quick" or "explained"). + + Returns: + str: Formatted response string (or prompt). + """ + if not products: + return "I couldn't find any products matching your search. Try adjusting your filters or search terms." + + if mode == "quick": + return self.prompts.quick_results_template(query, filters or {}, products) + + # Explained mode - use appropriate prompt based on intent + if intent == QueryIntent.COMPARISON: + return self.prompts.comparison_prompt(products) + elif intent == QueryIntent.FILTERED_SEARCH: + return self.prompts.filtered_search_prompt(query, filters or {}, products) + else: + return self.prompts.semantic_browse_prompt(query, products) + + def should_use_quick_mode( + self, + intent: str, + product_count: int, + has_filters: bool + ) -> bool: + """ + Determine if quick mode should be used. + + Quick mode is preferred for large result sets or explicit filtering, + where an LLM explanation adds latency without much value. + + Args: + intent (str): Query intent. + product_count (int): Number of products found. + has_filters (bool): Whether filters were applied. + + Returns: + bool: True if quick mode should be used, False for detailed explanation. + """ + # Use quick mode for: + # - Filtered searches with clear intent + # - Large result sets (>20 products) + # - Simple queries + + if intent == QueryIntent.FILTERED_SEARCH and has_filters: + return True + + if product_count > 20: + return True + + return False + diff --git a/sample_solutions/HybridSearch/api/retrieval/Dockerfile b/sample_solutions/HybridSearch/api/retrieval/Dockerfile new file mode 100644 index 00000000..5e26fd06 --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/Dockerfile @@ -0,0 +1,35 @@ +# Retrieval Service Dockerfile +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code and create non-root user +COPY . . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app && \ + mkdir -p /data/indexes && \ + chown -R appuser:appuser /data +USER appuser + +# Expose port +EXPOSE 8002 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8002/health || exit 1 + +# Run the application +CMD ["python", "main.py"] + diff --git a/sample_solutions/HybridSearch/api/retrieval/api_client.py b/sample_solutions/HybridSearch/api/retrieval/api_client.py new file mode 100644 index 00000000..d7d331dc --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/api_client.py @@ -0,0 +1,175 @@ +""" +API Client for GenAI Gateway authentication and enterprise API calls (Retrieval Service) +""" + +import httpx +import logging +import re +from config import settings + +logger = logging.getLogger(__name__) + + +def clean_url(url: str) -> str: + """ + Remove invisible characters and whitespace from URL. + + Args: + url (str): The URL string to clean. + + Returns: + str: The cleaned URL string. + """ + if not url: + return url + # Remove non-printable characters, whitespace, and specific zero-width chars + return re.sub(r'[\x00-\x1f\x7f-\x9f\s\u200b\u2060\ufeff]+', '', url) + + +class APIClient: + """ + Client for handling GenAI Gateway authentication and API calls. + + Specialized for the Retrieval Service to handle authentication for + enterprise reranking endpoints. + """ + + def __init__(self): + # Use per-model endpoint if set (APISIX/Keycloak), otherwise fall back to GenAI Gateway URL + self.use_apisix = bool(settings.reranker_api_endpoint) + # TEI (Gaudi) does not use /v1 prefix; vLLM (Xeon) does + self.use_tei = settings.inference_backend.lower() == "tei" + base_url = settings.reranker_api_endpoint or settings.genai_gateway_url + self.base_url = clean_url(base_url).rstrip('/') if base_url else None + self.token = settings.genai_api_key + self.http_client = httpx.Client(verify=settings.verify_ssl, timeout=60.0) if self.token else None + + if self.token and self.base_url: + backend = "APISIX" if self.use_apisix else f"GenAI Gateway ({settings.inference_backend})" + logger.info(f"Using {backend} at {self.base_url}") + + def get_rerank_client(self): + """ + Get info for reranking client. + + Returns: + tuple: (client_base_url, token) + """ + if not self.token or not self.base_url: + raise ValueError("GenAI Gateway configuration missing. Check GENAI_GATEWAY_URL and GENAI_API_KEY.") + + client_base_url = f"{self.base_url}" + return client_base_url, self.token + + def rerank_pairs(self, query: str, docs: list[str]) -> list[float]: + """ + Perform reranking using the GenAI Gateway reranking endpoint. + + Args: + query (str): The search query. + docs (list[str]): List of document texts to rerank against the query. + + Returns: + list[float]: List of relevance scores corresponding to the input docs. + + Raises: + Exception: If the reranker API call fails. + """ + if not self.token or not self.base_url: + raise ValueError("GenAI Gateway configuration missing. Check GENAI_GATEWAY_URL and GENAI_API_KEY.") + + # APISIX or TEI (Gaudi): /rerank | GenAI Gateway + vLLM (Xeon): /v1/rerank + use_no_v1 = self.use_apisix or self.use_tei + url = f"{self.base_url}/rerank" if use_no_v1 else f"{self.base_url}/v1/rerank" + + if not self.http_client: + self.http_client = httpx.Client(verify=settings.verify_ssl, timeout=60.0) + + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + # Truncate each doc to ~500 chars (~125 tokens) so query + doc + # stays well within the reranker model's 512-token max sequence length. + # 500 chars handles worst-case tokenization (technical text ~2 chars/token) + # and keeps total tokens safely under the model's 512-token max. + max_doc_chars = 500 + truncated_docs = [d[:max_doc_chars] for d in docs] + + # Split into batches to respect the model's max batch size + batch_size = settings.reranker_max_batch_size + scores = [0.0] * len(docs) + + for batch_start in range(0, len(truncated_docs), batch_size): + batch = truncated_docs[batch_start:batch_start + batch_size] + + # Keycloak/APISIX uses "texts"; GenAI Gateway (LiteLLM/Cohere) uses "documents" + if self.use_apisix: + payload = { + "model": settings.reranker_model_name, + "query": query, + "texts": batch, + "top_n": len(batch), + "return_documents": False + } + else: + payload = { + "model": settings.reranker_model_name, + "query": query, + "documents": batch, + "top_n": len(batch), + "return_documents": False + } + + response = self.http_client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + logger.error(f"Reranker API error: {response.status_code} - {response.text}") + response.raise_for_status() + + response_data = response.json() + logger.info(f"Reranker raw response: {response_data}") + + # Handle both response formats: + # Format 1 (vLLM/APISIX): [{"index": 0, "score": 0.9}, ...] + # Format 2 (LiteLLM/Cohere): {"results": [{"index": 0, "relevance_score": 0.9}, ...]} + if isinstance(response_data, list): + results = response_data + else: + results = response_data.get("results", []) + + for res in results: + original_idx = batch_start + res["index"] + if isinstance(response_data, list): + scores[original_idx] = res["score"] + else: + scores[original_idx] = res["relevance_score"] + + return scores + + def is_authenticated(self) -> bool: + """ + Check if client is authenticated. + + Returns: + bool: True if authenticated, False otherwise. + """ + return bool(self.token and self.http_client) + + +# Global instance +_api_client = None + + +def get_api_client(): + """ + Get or create global API client instance. + + Returns: + APIClient: The global singleton instance of APIClient. + """ + global _api_client + if _api_client is None: + _api_client = APIClient() + return _api_client diff --git a/sample_solutions/HybridSearch/api/retrieval/config.py b/sample_solutions/HybridSearch/api/retrieval/config.py new file mode 100644 index 00000000..baad5cc7 --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/config.py @@ -0,0 +1,115 @@ +""" +Retrieval Service Configuration +Manages environment variables and service settings +""" + +from pydantic_settings import BaseSettings +from typing import Optional +from pathlib import Path + +# Compute project root path +_PROJECT_ROOT = Path(__file__).parent.parent.parent +_DEFAULT_INDEX_PATH = str(_PROJECT_ROOT / "data" / "indexes") + + +class Settings(BaseSettings): + """ + Service configuration with environment variable loading. + + Manages: + - Service networking (host/port) + - Embedding service connection + - GenAI Gateway/APISIX Gateway credentials + - Model endpoints (Reranker) + - File paths for Dense (FAISS) and Sparse (BM25) indexes + - Retrieval parameters (top-k, fusion K) + """ + + # Deployment Phase + deployment_phase: str = "development" + + # Service Configuration + retrieval_port: int = 8002 + retrieval_host: str = "0.0.0.0" # nosec B104 - Binding to all interfaces is intentional for Docker container + + # Embedding Service + embedding_service_url: str = "http://localhost:8001" + + # GenAI Gateway Configuration + # Supports multiple deployment patterns: + # - GenAI Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + # - APISIX Gateway: Provide GENAI_GATEWAY_URL and GENAI_API_KEY + genai_gateway_url: Optional[str] = None + genai_api_key: Optional[str] = None + + # Per-model endpoint URL (required for APISIX, optional for GenAI Gateway) + reranker_api_endpoint: Optional[str] = None + + # Inference backend type: "vllm" (Xeon, default) or "tei" (Gaudi) + # TEI does not use the /v1 path prefix; vLLM does + inference_backend: str = "vllm" + + # Reranker Model Configuration (for Enterprise) + reranker_model_endpoint: str = "bge-reranker-base-vllmcpu" + reranker_model_name: str = "BAAI/bge-reranker-base" + reranker_max_batch_size: int = 32 # Max docs per rerank request (model-dependent) + + # Index Storage Path - default to /data/indexes in Docker + index_storage_path: str = "/data/indexes" + + # Individual Index Paths - can be overridden by environment variables + faiss_index_path: str = "/data/indexes/faiss_index.bin" + bm25_index_path: str = "/data/indexes/bm25_index.pkl" + metadata_index_path: str = "/data/indexes/metadata.pkl" + + # Product Index Paths + product_faiss_index_path: str = "/data/indexes/product_faiss_index.bin" + product_bm25_index_path: str = "/data/indexes/product_bm25_index.pkl" + product_metadata_index_path: str = "/data/indexes/product_metadata.pkl" + + # Retrieval Configuration + top_k_dense: int = 100 + top_k_sparse: int = 100 + top_k_fusion: int = 50 + top_k_rerank: int = 10 + use_reranking: bool = False # Skip in dev phase + rrf_k: int = 60 # RRF constant + + # Product Catalog Settings + system_mode: str = "document" # "document" or "product" + default_result_limit: int = 20 + + # SSL Verification Settings + verify_ssl: bool = True + + # Logging + log_level: str = "INFO" + + class Config: + # Look for .env file in the hybrid-search root directory + env_file = Path(__file__).parent.parent.parent / ".env" + case_sensitive = False + extra = "ignore" # Ignore extra fields in .env file + + def model_post_init(self, __context: any) -> None: + if not self.is_enterprise_configured(): + if self.use_reranking: + # If reranking is enabled, we strictly need GenAI Gateway auth + raise ValueError( + "GenAI Gateway configuration missing for RERANKING. " + "Must provide GENAI_GATEWAY_URL and GENAI_API_KEY in .env file, " + "OR set USE_RERANKING=false." + ) + + def is_enterprise_configured(self) -> bool: + """ + Check if GenAI Gateway is configured. + + Returns: + bool: True if genai_gateway_url and genai_api_key are present. + """ + return bool(self.genai_gateway_url and self.genai_api_key) + + +# Global settings instance +settings = Settings() diff --git a/sample_solutions/HybridSearch/api/retrieval/main.py b/sample_solutions/HybridSearch/api/retrieval/main.py new file mode 100644 index 00000000..05b9c4fb --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/main.py @@ -0,0 +1,640 @@ +""" +Retrieval Service +Hybrid search with FAISS + BM25 + RRF +""" + +import logging +import time +import httpx +from typing import List, Optional, Dict, Any +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from config import settings +from services.dense_retrieval import DenseRetrieval +from services.sparse_retrieval import SparseRetrieval +from services.fusion import ReciprocalRankFusion + +# Configure logging +logging.basicConfig( + level=settings.log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title="Retrieval Service", + description="Hybrid search with dense + sparse + reranking", + version="1.0.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize retrieval components +dense_retrieval = DenseRetrieval( + index_path=settings.faiss_index_path, + metadata_path=settings.metadata_index_path +) + +sparse_retrieval = SparseRetrieval( + index_path=settings.bm25_index_path, + metadata_path=settings.metadata_index_path +) + +# Product retrieval components +product_dense_retrieval = DenseRetrieval( + index_path=settings.product_faiss_index_path, + metadata_path=settings.product_metadata_index_path +) + +product_sparse_retrieval = SparseRetrieval( + index_path=settings.product_bm25_index_path, + metadata_path=settings.product_metadata_index_path +) + +rrf_fusion = ReciprocalRankFusion(k=settings.rrf_k) + + +# Request/Response Models +class HybridRetrievalRequest(BaseModel): + """ + Request model for hybrid retrieval. + + Attributes: + query: Search query string. + top_k_candidates: Number of candidates to fetch from each method (dense/sparse) before fusion. + top_k_fusion: Number of top results to keep after Reciprocal Rank Fusion. + top_k_final: Number of final results to return after reranking (if enabled). + """ + query: str = Field(..., description="Query string") + top_k_candidates: int = Field( + 100, + description="Number of candidates per method before fusion" + ) + top_k_fusion: int = Field( + 50, + description="Number of results after RRF fusion" + ) + top_k_final: int = Field( + 10, + description="Number of final results after reranking" + ) + + class Config: + json_schema_extra = { + "example": { + "query": "What are the key findings?", + "top_k_candidates": 100, + "top_k_fusion": 50, + "top_k_final": 10 + } + } + + +class RetrievalResult(BaseModel): + """ + Single retrieval result model. + + Attributes: + chunk_id: Unique chunk identifier. + document_id: Parent document identifier. + text: Text content of the chunk. + page_number: Page number (optional). + score: Relevance score (from fusion or reranking). + rank: Final rank position (1-based). + retrieval_method: Method that found this result (e.g., 'hybrid', 'dense', 'sparse'). + metadata: Additional metadata dictionary. + """ + chunk_id: str + document_id: str + text: str + page_number: Optional[int] = None + score: float + rank: int + retrieval_method: str + metadata: Dict[str, Any] = {} + + +class HybridRetrievalResponse(BaseModel): + """ + Response model for hybrid retrieval. + + Attributes: + results: List of ranked retrieval results. + retrieval_time_ms: Total retrieval time. + dense_time_ms: Time taken for dense search phase. + sparse_time_ms: Time taken for sparse search phase. + fusion_time_ms: Time taken for fusion phase. + query: Original query. + total_candidates: Total raw candidates found before fusion. + """ + results: List[RetrievalResult] + retrieval_time_ms: float + dense_time_ms: float + sparse_time_ms: float + fusion_time_ms: float + query: str + total_candidates: int + + +class IndexStats(BaseModel): + """Index statistics""" + dense_stats: Dict + sparse_stats: Dict + deployment_phase: str + + +class HealthResponse(BaseModel): + """Health check response""" + status: str + service: str + deployment_phase: str + indexes_loaded: Dict[str, bool] + + +class ProductSearchRequest(BaseModel): + """Request model for product search""" + query_embedding: Optional[List[float]] = Field(None, description="Pre-computed query embedding") + query_text: str = Field(..., description="Query text for BM25") + filters: Dict[str, Any] = Field(default_factory=dict, description="Product filters") + top_k: int = Field(20, description="Number of results to return", ge=1, le=100) + + +class ProductSearchResponse(BaseModel): + """ + Response model for product search. + + Attributes: + results: List of product dictionaries. + total_matches: Total number of matches found. + retrieval_time_ms: Time taken for search in milliseconds. + """ + results: List[Dict] + total_matches: int + retrieval_time_ms: float + + +# Helper Functions +async def get_query_embedding(query: str) -> List[float]: + """ + Get query embedding from embedding service. + + Args: + query (str): Query string to encode. + + Returns: + List[float]: Query embedding vector. + + Raises: + httpx.HTTPError: If embedding service is unreachable or returns error. + """ + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{settings.embedding_service_url}/api/v1/embeddings/encode", + json={"texts": [query], "normalize": True} + ) + response.raise_for_status() + data = response.json() + return data["embeddings"][0] + + +# API Endpoints +@app.post( + "/api/v1/retrieve/hybrid", + response_model=HybridRetrievalResponse, + status_code=status.HTTP_200_OK, + summary="Hybrid search", + description="Search using dense + sparse + fusion + reranking" +) +async def hybrid_search(request: HybridRetrievalRequest): + """ + Perform hybrid search using Dense + Sparse + Fusion + Reranking. + + Orchestrates the retrieval pipeline: + 1. Generates query embedding. + 2. Runs parallel dense (FAISS) and sparse (BM25) searches. + 3. Fuses results using Reciprocal Rank Fusion (RRF). + 4. Optionally reranks top results using a cross-encoder model. + + Args: + request (HybridRetrievalRequest): Search parameters. + + Returns: + HybridRetrievalResponse: Ranked search results and timing metrics. + + Raises: + HTTPException: If search fails. + """ + try: + start_time = time.time() + + logger.info(f"Hybrid search query: {request.query[:100]}") + + # Get query embedding + query_embedding = await get_query_embedding(request.query) + + # Dense retrieval + dense_start = time.time() + dense_results = dense_retrieval.search( + query_embedding, + top_k=request.top_k_candidates + ) + dense_time = (time.time() - dense_start) * 1000 + + # Sparse retrieval + sparse_start = time.time() + sparse_results = sparse_retrieval.search( + request.query, + top_k=request.top_k_candidates + ) + sparse_time = (time.time() - sparse_start) * 1000 + + # Fusion + fusion_start = time.time() + fused_results = rrf_fusion.fuse( + dense_results, + sparse_results, + top_k=request.top_k_fusion + ) + + # Reranking (enterprise cross-encoder) + if settings.use_reranking: + final_results = rrf_fusion.rerank( + request.query, + fused_results, + top_k=request.top_k_final + ) + else: + final_results = fused_results[:request.top_k_final] + + fusion_time = (time.time() - fusion_start) * 1000 + + # Calculate total time + total_time = (time.time() - start_time) * 1000 + + # Format results + formatted_results = [ + RetrievalResult(**result) for result in final_results + ] + + logger.info( + f"Hybrid search completed: {len(formatted_results)} results in {total_time:.2f}ms " + f"(dense: {dense_time:.2f}ms, sparse: {sparse_time:.2f}ms, fusion: {fusion_time:.2f}ms)" + ) + + return HybridRetrievalResponse( + results=formatted_results, + retrieval_time_ms=round(total_time, 2), + dense_time_ms=round(dense_time, 2), + sparse_time_ms=round(sparse_time, 2), + fusion_time_ms=round(fusion_time, 2), + query=request.query, + total_candidates=len(dense_results) + len(sparse_results) + ) + + except Exception as e: + logger.error(f"Error during hybrid search: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Hybrid search failed: {str(e)}" + ) + + +@app.post( + "/api/v1/retrieve/dense-only", + response_model=HybridRetrievalResponse, + status_code=status.HTTP_200_OK, + summary="Dense search only", + description="Search using only FAISS dense retrieval" +) +async def dense_only_search(request: HybridRetrievalRequest): + """ + Perform dense-only search (FAISS). + + Used for testing or when keyword matching is not needed. + + Args: + request (HybridRetrievalRequest): Search parameters. + + Returns: + HybridRetrievalResponse: Dense search results. + """ + try: + start_time = time.time() + + query_embedding = await get_query_embedding(request.query) + + dense_start = time.time() + dense_results = dense_retrieval.search( + query_embedding, + top_k=request.top_k_final + ) + dense_time = (time.time() - dense_start) * 1000 + + total_time = (time.time() - start_time) * 1000 + + formatted_results = [ + RetrievalResult(**result) for result in dense_results + ] + + return HybridRetrievalResponse( + results=formatted_results, + retrieval_time_ms=round(total_time, 2), + dense_time_ms=round(dense_time, 2), + sparse_time_ms=0, + fusion_time_ms=0, + query=request.query, + total_candidates=len(dense_results) + ) + + except Exception as e: + logger.error(f"Error during dense search: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Dense search failed: {str(e)}" + ) + + +@app.post( + "/api/v1/retrieve/sparse-only", + response_model=HybridRetrievalResponse, + status_code=status.HTTP_200_OK, + summary="Sparse search only", + description="Search using only BM25 sparse retrieval" +) +async def sparse_only_search(request: HybridRetrievalRequest): + """ + Perform sparse-only search (BM25). + + Used for testing or exact keyword matching. + + Args: + request (HybridRetrievalRequest): Search parameters. + + Returns: + HybridRetrievalResponse: Sparse search results. + """ + try: + start_time = time.time() + + sparse_start = time.time() + sparse_results = sparse_retrieval.search( + request.query, + top_k=request.top_k_final + ) + sparse_time = (time.time() - sparse_start) * 1000 + + total_time = (time.time() - start_time) * 1000 + + formatted_results = [ + RetrievalResult(**result) for result in sparse_results + ] + + return HybridRetrievalResponse( + results=formatted_results, + retrieval_time_ms=round(total_time, 2), + dense_time_ms=0, + sparse_time_ms=round(sparse_time, 2), + fusion_time_ms=0, + query=request.query, + total_candidates=len(sparse_results) + ) + + except Exception as e: + logger.error(f"Error during sparse search: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Sparse search failed: {str(e)}" + ) + + +@app.get( + "/api/v1/retrieve/stats", + response_model=IndexStats, + status_code=status.HTTP_200_OK, + summary="Get index statistics" +) +async def get_stats(): + """ + Get retrieval index statistics. + + Returns: + IndexStats: Statistics for both dense and sparse indexes. + """ + return IndexStats( + dense_stats=dense_retrieval.get_stats(), + sparse_stats=sparse_retrieval.get_stats(), + deployment_phase=settings.deployment_phase + ) + + +@app.post( + "/api/v1/reload", + status_code=status.HTTP_200_OK, + summary="Reload indexes", + description="Reload indexes from disk (useful after clearing or updating indexes)" +) +async def reload_indexes(): + """ + Reload all indexes from disk. + + Useful after clearing or re-ingesting data without restarting the service. + + Returns: + dict: Status message and loaded index flags. + """ + try: + dense_retrieval.reload() + sparse_retrieval.reload() + product_dense_retrieval.reload() + product_sparse_retrieval.reload() + + logger.info("Successfully reloaded all indexes") + + return { + "message": "Indexes reloaded successfully", + "status": "success", + "indexes_loaded": { + "dense": dense_retrieval.index is not None, + "sparse": sparse_retrieval.bm25 is not None, + "dense_vectors": dense_retrieval.index.ntotal if dense_retrieval.index else 0, + "sparse_documents": len(sparse_retrieval.metadata), + "product_dense": product_dense_retrieval.index is not None, + "product_sparse": product_sparse_retrieval.bm25 is not None, + "product_vectors": product_dense_retrieval.index.ntotal if product_dense_retrieval.index else 0 + } + } + except Exception as e: + logger.error(f"Error reloading indexes: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to reload indexes: {str(e)}" + ) + + +@app.get( + "/health", + response_model=HealthResponse, + status_code=status.HTTP_200_OK, + summary="Health check" +) +async def health_check(): + """Health check endpoint""" + return HealthResponse( + status="healthy", + service="retrieval", + deployment_phase=settings.deployment_phase, + indexes_loaded={ + "dense": dense_retrieval.index is not None, + "sparse": sparse_retrieval.bm25 is not None, + "product_dense": product_dense_retrieval.index is not None, + "product_sparse": product_sparse_retrieval.bm25 is not None + } + ) + + +@app.post( + "/api/v1/search/products", + response_model=ProductSearchResponse, + status_code=status.HTTP_200_OK, + summary="Product search", + description="Search products with filters using hybrid retrieval" +) +async def search_products(request: ProductSearchRequest): + """ + Search products with filters using hybrid retrieval. + + Extends hybrid search to support product-specific metadata filtering + and result formatting. + + Args: + request (ProductSearchRequest): Query and filter parameters. + + Returns: + ProductSearchResponse: Formatted product results. + """ + try: + start_time = time.time() + + logger.info(f"Product search: query='{request.query_text[:100]}', filters={request.filters}") + + # Get query embedding if not provided + query_embedding = request.query_embedding + if not query_embedding: + query_embedding = await get_query_embedding(request.query_text) + + # Dense retrieval with filters (uses UNIFIED index, filter by content_type=product) + dense_start = time.time() + dense_results = dense_retrieval.search( + query_embedding, + top_k=request.top_k * 5, # Get more candidates for filtering + filters=request.filters, + product_mode=True + ) + dense_time = (time.time() - dense_start) * 1000 + + # Sparse retrieval with filters (uses UNIFIED index, filter by content_type=product) + sparse_start = time.time() + sparse_results = sparse_retrieval.search( + request.query_text, + top_k=request.top_k * 5, # Get more candidates for filtering + filters=request.filters, + product_mode=True + ) + sparse_time = (time.time() - sparse_start) * 1000 + + # Fusion with enrichment (product mode) + fusion_start = time.time() + fused_results = rrf_fusion.fuse( + dense_results, + sparse_results, + top_k=request.top_k, + product_mode=True, + filters=request.filters + ) + fusion_time = (time.time() - fusion_start) * 1000 + + # Format results for products + product_results = [] + for result in fused_results: + metadata = result.get('metadata', {}) + product_id = metadata.get('product_id') or result.get('document_id') + + # Format as ProductSearchResult + product_result = { + "product_id": product_id, + "name": metadata.get('name', ''), + "description": (result.get('text', '') or metadata.get('description', ''))[:200], + "category": metadata.get('category'), + "price": metadata.get('price'), + "rating": metadata.get('rating'), + "review_count": metadata.get('review_count'), + "image_url": metadata.get('image_url'), + "relevance_score": result.get('relevance_score', result.get('rrf_score', 0.0)), + "match_reasons": result.get('match_reasons', []), + "attributes": {} # Would be populated from product_attributes table + } + product_results.append(product_result) + + total_time = (time.time() - start_time) * 1000 + + logger.info( + f"Product search completed: {len(product_results)} results in {total_time:.2f}ms " + f"(dense: {dense_time:.2f}ms, sparse: {sparse_time:.2f}ms, fusion: {fusion_time:.2f}ms)" + ) + + return ProductSearchResponse( + results=product_results, + total_matches=len(product_results), # This would be calculated before filtering + retrieval_time_ms=round(total_time, 2) + ) + + except Exception as e: + logger.error(f"Error during product search: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Product search failed: {str(e)}" + ) + + +@app.get("/", summary="Root endpoint") +async def root(): + """ + Root endpoint with service information. + + Returns: + dict: Basic service info (version, status). + """ + return { + "service": "Retrieval Service", + "version": "1.0.0", + "status": "running", + "docs": "/docs", + "health": "/health" + } + + +# Application startup +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting Retrieval Service on {settings.retrieval_host}:{settings.retrieval_port}") + logger.info(f"Deployment phase: {settings.deployment_phase}") + logger.info(f"FAISS index: {settings.faiss_index_path}") + logger.info(f"BM25 index: {settings.bm25_index_path}") + + uvicorn.run( + app, + host=settings.retrieval_host, # nosec B104 - Binding to all interfaces is intentional for Docker container + port=settings.retrieval_port, + log_level=settings.log_level.lower() + ) + diff --git a/sample_solutions/HybridSearch/api/retrieval/requirements.txt b/sample_solutions/HybridSearch/api/retrieval/requirements.txt new file mode 100644 index 00000000..3b18f0ff --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/requirements.txt @@ -0,0 +1,35 @@ +# Retrieval Service Requirements +# API Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 + +# Vector Search +faiss-cpu==1.7.4 +numpy==1.24.3 + +# Sparse Search +rank-bm25==0.2.2 + +# HTTP Client (for calling embedding service and GenAI Gateway APIs) +httpx==0.25.1 +requests>=2.32.0 + +# OpenAI SDK (used as client for OpenAI-compatible enterprise APIs) +openai>=1.35.0 + +# Logging +python-json-logger==2.0.7 + +# Environment +python-dotenv==1.0.0 + +# Utilities +python-multipart>=0.0.22 + +# Production Phase Dependencies (will be used later) +# sentence-transformers==2.2.2 # For reranking +# torch==2.1.0 +# transformers==4.35.0 + diff --git a/sample_solutions/HybridSearch/api/retrieval/services/__init__.py b/sample_solutions/HybridSearch/api/retrieval/services/__init__.py new file mode 100644 index 00000000..dc3e39a0 --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/services/__init__.py @@ -0,0 +1,4 @@ +""" +Retrieval Service Modules +""" + diff --git a/sample_solutions/HybridSearch/api/retrieval/services/dense_retrieval.py b/sample_solutions/HybridSearch/api/retrieval/services/dense_retrieval.py new file mode 100644 index 00000000..00557909 --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/services/dense_retrieval.py @@ -0,0 +1,195 @@ +""" +Dense Retrieval using FAISS +Semantic search using vector embeddings +""" + +import logging +import pickle +from pathlib import Path +from typing import List, Dict, Tuple +import numpy as np +import faiss + +logger = logging.getLogger(__name__) + + +class DenseRetrieval: + """ + FAISS-based dense retrieval system. + + Manages loading of FAISS indexes and metadata, and performing semantic search + using vector embeddings. Supports both document and product search modes. + """ + + def __init__( + self, + index_path: str, + metadata_path: str + ): + """ + Initialize dense retrieval. + + Args: + index_path (str): Path to the FAISS index file (.bin). + metadata_path (str): Path to the matching metadata pickle file (.pkl). + """ + self.index_path = Path(index_path) + self.metadata_path = Path(metadata_path) + + self.index = None + self.metadata = [] + + self._load_index() + + def _load_index(self): + """ + Load FAISS index and metadata from disk. + + Handles missing files gracefully by initializing empty state. + """ + try: + if self.index_path.exists(): + self.index = faiss.read_index(str(self.index_path)) + logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors") + else: + logger.warning(f"FAISS index not found at {self.index_path}") + self.index = None + + if self.metadata_path.exists(): + with open(self.metadata_path, 'rb') as f: + self.metadata = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info(f"Loaded {len(self.metadata)} metadata entries") + else: + logger.warning(f"Metadata not found at {self.metadata_path}") + self.metadata = [] + + except Exception as e: + logger.error(f"Error loading index: {e}") + self.index = None + self.metadata = [] + + def search( + self, + query_embedding: List[float], + top_k: int = 100, + filters: Dict = None, + product_mode: bool = False + ) -> List[Dict]: + """ + Search for similar vectors with optional filters. + + Args: + query_embedding (List[float]): Query vector (will be normalized). + top_k (int): Number of results to return. + filters (Dict, optional): Metadata filters (price, rating, category). + product_mode (bool): If True, filters results to only 'product' type items. + + Returns: + List[Dict]: List of result dictionaries containing metadata, score, rank, + and retrieval method. + """ + if not self.index or self.index.ntotal == 0: + logger.warning("Index is empty or not loaded") + return [] + + try: + # Convert to numpy array and normalize + query_vector = np.array([query_embedding], dtype=np.float32) + faiss.normalize_L2(query_vector) + + # Retrieve more candidates if filters are applied (for post-retrieval filtering) + k = min(top_k * 5 if filters else top_k, self.index.ntotal) + distances, indices = self.index.search(query_vector, k) + + # Format results + results = [] + for idx, (distance, index) in enumerate(zip(distances[0], indices[0])): + if index < len(self.metadata): + result = { + **self.metadata[index], + "score": float(distance), # Cosine similarity + "rank": idx + 1, + "retrieval_method": "dense" + } + results.append(result) + + # Filter by content_type if in product_mode + if product_mode: + results = [r for r in results if r.get('content_type') == 'product'] + logger.debug(f"Filtered to {len(results)} product results") + + # Apply other filters if provided + if filters: + results = self._apply_filters(results, filters) + # Limit to top_k after filtering + results = results[:top_k] + + logger.info(f"Dense retrieval found {len(results)} results") + return results + + except Exception as e: + logger.error(f"Error during dense search: {e}") + return [] + + def _apply_filters(self, results: List[Dict], filters: Dict) -> List[Dict]: + """ + Apply post-retrieval metadata filters. + + Args: + results (List[Dict]): List of result dictionaries. + filters (Dict): Filter dictionary (e.g., {'price_min': 10}). + + Returns: + List[Dict]: Filtered list of results. + """ + filtered = [] + + for result in results: + metadata = result.get('metadata', {}) + + # Price filters + if 'price_min' in filters or 'price_max' in filters: + price = metadata.get('price') + if price is None: + continue # Skip products without price if price filter is set + + if 'price_min' in filters and price < filters['price_min']: + continue + if 'price_max' in filters and price > filters['price_max']: + continue + + # Rating filters + if 'rating_min' in filters: + rating = metadata.get('rating') + if rating is None or rating < filters['rating_min']: + continue + + # Category filters + if 'categories' in filters and filters['categories']: + category = metadata.get('category') + if category not in filters['categories']: + continue + + filtered.append(result) + + logger.info(f"Filtered {len(results)} results to {len(filtered)}") + return filtered + + def reload(self): + """Reload index and metadata from disk.""" + logger.info("Reloading FAISS index from disk") + self._load_index() + + def get_stats(self) -> Dict: + """ + Get index statistics. + + Returns: + Dict: Dictionary containing 'total_vectors', 'total_metadata', and 'index_loaded'. + """ + return { + "total_vectors": self.index.ntotal if self.index else 0, + "total_metadata": len(self.metadata), + "index_loaded": self.index is not None + } + diff --git a/sample_solutions/HybridSearch/api/retrieval/services/fusion.py b/sample_solutions/HybridSearch/api/retrieval/services/fusion.py new file mode 100644 index 00000000..d09660df --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/services/fusion.py @@ -0,0 +1,182 @@ +""" +Reciprocal Rank Fusion (RRF) +Combines results from multiple retrieval methods +""" + +import logging +from typing import List, Dict +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class ReciprocalRankFusion: + """ + RRF algorithm for combining ranked lists. + + Implements Reciprocal Rank Fusion to combine results from multiple retrieval + sources (e.g., Dense and Sparse) into a single ranked list. + """ + + def __init__(self, k: int = 60): + """ + Initialize RRF. + + Args: + k (int): RRF constant (default factor for rank penalty). + Higher k reduces the impact of high rankings. + """ + self.k = k + + def fuse( + self, + dense_results: List[Dict], + sparse_results: List[Dict], + top_k: int = 50, + enrich_results: bool = False, + product_mode: bool = False, + filters: Dict = None + ) -> List[Dict]: + """ + Fuse dense and sparse retrieval results. + + Args: + dense_results (List[Dict]): Results from dense retrieval. + sparse_results (List[Dict]): Results from sparse retrieval. + top_k (int): Number of results to return after fusion. + enrich_results (bool): Whether to add match reasons (for products). + product_mode (bool): Unused flag kept for interface compatibility. + filters (Dict): Unused filters kept for interface compatibility. + + Returns: + List[Dict]: List of fused results sorted by RRF score. + """ + # Calculate RRF scores + rrf_scores = defaultdict(float) + chunk_data = {} # Store chunk info + + # Process dense results + for rank, result in enumerate(dense_results, 1): + chunk_id = result.get("chunk_id") or result.get("metadata", {}).get("product_id") + if chunk_id: + rrf_scores[chunk_id] += 1 / (self.k + rank) + if chunk_id not in chunk_data: + chunk_data[chunk_id] = result + + # Process sparse results + for rank, result in enumerate(sparse_results, 1): + chunk_id = result.get("chunk_id") or result.get("metadata", {}).get("product_id") + if chunk_id: + rrf_scores[chunk_id] += 1 / (self.k + rank) + if chunk_id not in chunk_data: + chunk_data[chunk_id] = result + + # Sort by RRF score + sorted_chunks = sorted( + rrf_scores.items(), + key=lambda x: x[1], + reverse=True + )[:top_k] + + # Format results + fused_results = [] + for rank, (chunk_id, rrf_score) in enumerate(sorted_chunks, 1): + if chunk_id in chunk_data: + result = { + **chunk_data[chunk_id], + "rrf_score": float(rrf_score), + "relevance_score": float(rrf_score), # Alias for product search + "rank": rank, + "retrieval_method": "hybrid" + } + + # Enrich with match reasons for products + if enrich_results: + match_reasons = self._generate_match_reasons(result, dense_results, sparse_results) + result["match_reasons"] = match_reasons + + fused_results.append(result) + + logger.info( + f"RRF fusion: {len(dense_results)} dense + {len(sparse_results)} sparse " + f"→ {len(fused_results)} fused results" + ) + + return fused_results + + def _generate_match_reasons( + self, + result: Dict, + dense_results: List[Dict], + sparse_results: List[Dict] + ) -> List[str]: + """ + Generate match reasons for a product result. + + Analyzes why a product was matched (e.g., semantic match, price match, + high rating, category match). + + Args: + result (Dict): The result dictionary to analyze. + dense_results (List[Dict]): Original dense retrieval results. + sparse_results (List[Dict]): Original sparse retrieval results. + + Returns: + List[str]: List of human-readable match reason strings. + """ + reasons = [] + metadata = result.get('metadata', {}) + product_id = metadata.get('product_id') + + # Check if in dense results (semantic match) + if any(r.get('chunk_id') == result.get('chunk_id') or + r.get('metadata', {}).get('product_id') == product_id + for r in dense_results[:10]): + reasons.append("Semantic match") + + # Check price filter + price = metadata.get('price') + if price is not None: + reasons.append(f"Price: ${price:.2f}") + + # Check rating + rating = metadata.get('rating') + if rating is not None and rating >= 4.0: + reasons.append(f"Highly rated ({rating:.1f} stars)") + + # Check category + category = metadata.get('category') + if category: + reasons.append(f"In {category}") + + return reasons if reasons else ["Matches your search"] + + def rerank( + self, + query: str, + results: List[Dict], + top_k: int = 10 + ) -> List[Dict]: + """ + Rerank results using cross-encoder. + + Wrapper to call the Reranker service. + + Args: + query (str): User query. + results (List[Dict]): Results to rerank. + top_k (int): Number of results to return. + + Returns: + List[Dict]: Top-k reranked results. + """ + from services.reranker import Reranker + reranker = Reranker() + + if not results: + return [] + + # In enterprise mode, we use the cross-encoder reranker + logger.info(f"Performing enterprise reranking for top {len(results)} results") + return reranker.rerank(query, results, top_k) + diff --git a/sample_solutions/HybridSearch/api/retrieval/services/reranker.py b/sample_solutions/HybridSearch/api/retrieval/services/reranker.py new file mode 100644 index 00000000..a0d3a1cd --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/services/reranker.py @@ -0,0 +1,77 @@ +""" +Reranker Service +Handles precise ranking of candidates using cross-encoders +""" + +import logging +import time +from typing import List, Dict, Any +from api_client import get_api_client +from config import settings + +logger = logging.getLogger(__name__) + +class Reranker: + """ + Enterprise Reranker implementation using Keycloak and BGE-Reranker. + + Delegates reranking to the enterprise API client. + """ + + def __init__(self): + self.api_client = get_api_client() + self.enabled = settings.use_reranking + + def rerank(self, query: str, candidates: List[Dict], top_k: int = 10) -> List[Dict]: + """ + Rerank a list of candidates based on the query. + + Sends pairs of (query, document_text) to the enterprise cross-encoder + API to get precise relevance scores. + + Args: + query (str): The user query. + candidates (List[Dict]): List of retrieved chunks to rerank. + top_k (int): Number of results to return. + + Returns: + List[Dict]: Top-k reranked results with updated scores. + """ + if not self.enabled or not candidates: + return candidates[:top_k] + + try: + start_time = time.time() + + # Prepare documents for reranking + # Enterprise cross-encoders typically expect query and doc text pairs + docs = [c.get("text", "") for c in candidates] + + logger.info(f"Reranking {len(docs)} candidates for query: '{query[:50]}...'") + + # Call enterprise rerank endpoint + scores = self.api_client.rerank_pairs(query, docs) + + # Update scores and sort + for i, score in enumerate(scores): + candidates[i]["rerank_score"] = float(score) + # Blend or replace original score + candidates[i]["original_score"] = candidates[i].get("score", 0.0) + candidates[i]["score"] = float(score) + + # Sort by new score + reranked = sorted( + candidates, + key=lambda x: x.get("score", 0.0), + reverse=True + ) + + duration = (time.time() - start_time) * 1000 + logger.info(f"Reranking completed in {duration:.2f}ms") + + return reranked[:top_k] + + except Exception as e: + logger.error(f"Reranking failed: {e}", exc_info=True) + # Fallback to original order + return candidates[:top_k] diff --git a/sample_solutions/HybridSearch/api/retrieval/services/sparse_retrieval.py b/sample_solutions/HybridSearch/api/retrieval/services/sparse_retrieval.py new file mode 100644 index 00000000..3fe9aef3 --- /dev/null +++ b/sample_solutions/HybridSearch/api/retrieval/services/sparse_retrieval.py @@ -0,0 +1,196 @@ +""" +Sparse Retrieval using BM25 +Lexical search using keyword matching +""" + +import logging +import pickle +from pathlib import Path +from typing import List, Dict +from rank_bm25 import BM25Okapi + +logger = logging.getLogger(__name__) + + +class SparseRetrieval: + """ + BM25-based sparse retrieval system. + + Manages loading of pre-computed BM25 indexes and metadata for lexical search. + Supports both document and product search modes. + """ + + def __init__( + self, + index_path: str, + metadata_path: str + ): + """ + Initialize sparse retrieval. + + Args: + index_path (str): Path to BM25 index pickle file. + metadata_path (str): Path to metadata pickle file. + """ + self.index_path = Path(index_path) + self.metadata_path = Path(metadata_path) + + self.bm25 = None + self.metadata = [] + + self._load_index() + + def _load_index(self): + """ + Load BM25 index and metadata from disk. + + Handles missing files gracefully by initializing empty state. + """ + try: + if self.index_path.exists(): + with open(self.index_path, 'rb') as f: + self.bm25 = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info("Loaded BM25 index") + else: + logger.warning(f"BM25 index not found at {self.index_path}") + self.bm25 = None + + if self.metadata_path.exists(): + with open(self.metadata_path, 'rb') as f: + self.metadata = pickle.load(f) # nosec B301 - indexes are written by this application + logger.info(f"Loaded {len(self.metadata)} metadata entries") + else: + logger.warning(f"Metadata not found at {self.metadata_path}") + self.metadata = [] + + except Exception as e: + logger.error(f"Error loading BM25 index: {e}") + self.bm25 = None + self.metadata = [] + + def search( + self, + query: str, + top_k: int = 100, + filters: Dict = None, + product_mode: bool = False + ) -> List[Dict]: + """ + Search using BM25 with optional filters. + + Args: + query (str): Query string. + top_k (int): Number of results to return. + filters (Dict, optional): Metadata filters. + product_mode (bool): If True, filters results to only 'product' type items. + + Returns: + List[Dict]: List of result dictionaries containing metadata, score, rank, + and retrieval method. + """ + if not self.bm25 or not self.metadata: + logger.warning("BM25 index is empty or not loaded") + return [] + + try: + # Tokenize query + tokenized_query = query.lower().split() + + # Get BM25 scores + scores = self.bm25.get_scores(tokenized_query) + + # Get more candidates if filters are applied + k = top_k * 5 if filters else top_k + top_indices = scores.argsort()[-k:][::-1] + + # Format results + results = [] + for rank, idx in enumerate(top_indices, 1): + if idx < len(self.metadata) and scores[idx] > 0: + result = { + **self.metadata[idx], + "score": float(scores[idx]), + "rank": rank, + "retrieval_method": "sparse" + } + results.append(result) + + # Filter by content_type if in product_mode + if product_mode: + results = [r for r in results if r.get('content_type') == 'product'] + logger.debug(f"Filtered to {len(results)} product results") + + # Apply other filters if provided + if filters: + results = self._apply_filters(results, filters) + # Limit to top_k after filtering + results = results[:top_k] + + logger.info(f"Sparse retrieval found {len(results)} results") + return results + + except Exception as e: + logger.error(f"Error during sparse search: {e}") + return [] + + def _apply_filters(self, results: List[Dict], filters: Dict) -> List[Dict]: + """ + Apply post-retrieval metadata filters. + + Args: + results (List[Dict]): List of result dictionaries. + filters (Dict): Filter dictionary. + + Returns: + List[Dict]: Filtered list of results. + """ + filtered = [] + + for result in results: + metadata = result.get('metadata', {}) + + # Price filters + if 'price_min' in filters or 'price_max' in filters: + price = metadata.get('price') + if price is None: + continue # Skip products without price if price filter is set + + if 'price_min' in filters and price < filters['price_min']: + continue + if 'price_max' in filters and price > filters['price_max']: + continue + + # Rating filters + if 'rating_min' in filters: + rating = metadata.get('rating') + if rating is None or rating < filters['rating_min']: + continue + + # Category filters + if 'categories' in filters and filters['categories']: + category = metadata.get('category') + if category not in filters['categories']: + continue + + filtered.append(result) + + logger.info(f"Filtered {len(results)} results to {len(filtered)}") + return filtered + + def reload(self): + """Reload index and metadata from disk.""" + logger.info("Reloading BM25 index from disk") + self._load_index() + + def get_stats(self) -> Dict: + """ + Get index statistics. + + Returns: + Dict: Dictionary containing 'total_documents' and 'bm25_loaded'. + """ + return { + "total_documents": len(self.metadata), + "bm25_loaded": self.bm25 is not None + } + diff --git a/sample_solutions/HybridSearch/data/documents/.gitkeep b/sample_solutions/HybridSearch/data/documents/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/sample_solutions/HybridSearch/data/indexes/.gitkeep b/sample_solutions/HybridSearch/data/indexes/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/sample_solutions/HybridSearch/docker-compose.yml b/sample_solutions/HybridSearch/docker-compose.yml new file mode 100644 index 00000000..0caa7398 --- /dev/null +++ b/sample_solutions/HybridSearch/docker-compose.yml @@ -0,0 +1,223 @@ +services: + ui: + build: + context: ./ui + dockerfile: Dockerfile + container_name: hybrid-search-ui + ports: + - "${UI_PORT:-8501}:8501" + environment: + - GATEWAY_SERVICE_URL=http://gateway:${GATEWAY_PORT:-8000} + - INGESTION_SERVICE_URL=http://ingestion:${INGESTION_PORT:-8004} + - RETRIEVAL_SERVICE_URL=http://retrieval:${RETRIEVAL_PORT:-8002} + - UI_TITLE=${UI_TITLE:-InsightMapper Lite} + - UI_PAGE_ICON=${UI_PAGE_ICON:-📚} + depends_on: + - gateway + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "python", "-c", "import socket; s=socket.socket(); s.connect(('gateway', 8000))"] + interval: 30s + timeout: 10s + retries: 3 + + gateway: + build: + context: ./api/gateway + dockerfile: Dockerfile + container_name: hybrid-search-gateway + ports: + - "${GATEWAY_PORT:-8000}:8000" + environment: + - DEPLOYMENT_PHASE=${DEPLOYMENT_PHASE:-development} + - GATEWAY_PORT=8000 + - GATEWAY_LOG_LEVEL=${GATEWAY_LOG_LEVEL:-INFO} + - SYSTEM_MODE=${SYSTEM_MODE:-document} + - EMBEDDING_SERVICE_URL=http://embedding:8001 + - RETRIEVAL_SERVICE_URL=http://retrieval:8002 + - LLM_SERVICE_URL=http://llm:8003 + - INGESTION_SERVICE_URL=http://ingestion:8004 + - VERIFY_SSL=${VERIFY_SSL:-true} + depends_on: + - embedding + - retrieval + - llm + - ingestion + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + + + embedding: + build: + context: ./api/embedding + dockerfile: Dockerfile + container_name: hybrid-search-embedding + ports: + - "${EMBEDDING_PORT:-8001}:8001" + environment: + - DEPLOYMENT_PHASE=${DEPLOYMENT_PHASE:-development} + - OPENAI_EMBEDDING_MODEL=${OPENAI_EMBEDDING_MODEL:-text-embedding-3-large} + - OPENAI_EMBEDDING_DIMENSIONS=${OPENAI_EMBEDDING_DIMENSIONS:-768} + - EMBEDDING_PORT=8001 + - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-32} + - EMBEDDING_MAX_LENGTH=${EMBEDDING_MAX_LENGTH:-512} + - SYSTEM_MODE=${SYSTEM_MODE:-document} + - EMBEDDING_MODEL_ENDPOINT=${EMBEDDING_MODEL_ENDPOINT:-BAAI/bge-base-en-v1.5} + - EMBEDDING_MODEL_NAME=${EMBEDDING_MODEL_NAME:-BAAI/bge-base-en-v1.5} + # GenAI Gateway / APISIX Configuration + - GENAI_GATEWAY_URL=${GENAI_GATEWAY_URL} + - GENAI_API_KEY=${GENAI_API_KEY} + - EMBEDDING_API_ENDPOINT=${EMBEDDING_API_ENDPOINT:-} + - INFERENCE_BACKEND=${INFERENCE_BACKEND:-vllm} + - VERIFY_SSL=${VERIFY_SSL:-true} + extra_hosts: + - "${LOCAL_URL_ENDPOINT}:host-gateway" + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8001/health"] + interval: 30s + timeout: 10s + retries: 3 + + retrieval: + build: + context: ./api/retrieval + dockerfile: Dockerfile + container_name: hybrid-search-retrieval + ports: + - "${RETRIEVAL_PORT:-8002}:8002" + volumes: + - index_data:/data/indexes + environment: + - DEPLOYMENT_PHASE=${DEPLOYMENT_PHASE:-development} + - RETRIEVAL_PORT=8002 + - SYSTEM_MODE=${SYSTEM_MODE:-document} + - FAISS_INDEX_PATH=${FAISS_INDEX_PATH:-/data/indexes/faiss_index.bin} + - BM25_INDEX_PATH=${BM25_INDEX_PATH:-/data/indexes/bm25_index.pkl} + - METADATA_INDEX_PATH=${METADATA_INDEX_PATH:-/data/indexes/metadata.pkl} + - TOP_K_DENSE=${TOP_K_DENSE:-100} + - TOP_K_SPARSE=${TOP_K_SPARSE:-100} + - TOP_K_FUSION=${TOP_K_FUSION:-50} + - TOP_K_RERANK=${TOP_K_RERANK:-10} + - USE_RERANKING=${USE_RERANKING:-false} + - RRF_K=${RRF_K:-60} + - EMBEDDING_SERVICE_URL=http://embedding:8001 + - RERANKER_MODEL_ENDPOINT=${RERANKER_MODEL_ENDPOINT:-BAAI/bge-reranker-base} + - RERANKER_MODEL_NAME=${RERANKER_MODEL_NAME:-BAAI/bge-reranker-base} + - RERANKER_MAX_BATCH_SIZE=${RERANKER_MAX_BATCH_SIZE:-32} + # GenAI Gateway / APISIX Configuration + - GENAI_GATEWAY_URL=${GENAI_GATEWAY_URL} + - GENAI_API_KEY=${GENAI_API_KEY} + - RERANKER_API_ENDPOINT=${RERANKER_API_ENDPOINT:-} + - INFERENCE_BACKEND=${INFERENCE_BACKEND:-vllm} + - VERIFY_SSL=${VERIFY_SSL:-true} + extra_hosts: + - "${LOCAL_URL_ENDPOINT}:host-gateway" + depends_on: + - embedding + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8002/health"] + interval: 30s + timeout: 10s + retries: 3 + + + llm: + build: + context: ./api/llm + dockerfile: Dockerfile + container_name: hybrid-search-llm + ports: + - "${LLM_PORT:-8003}:8003" + environment: + - DEPLOYMENT_PHASE=${DEPLOYMENT_PHASE:-development} + - LLM_PORT=8003 + - SYSTEM_MODE=${SYSTEM_MODE:-document} + - MAX_TOKENS_SIMPLE=${MAX_TOKENS_SIMPLE:-512} + - MAX_TOKENS_COMPLEX=${MAX_TOKENS_COMPLEX:-1024} + - TEMPERATURE_SIMPLE=${TEMPERATURE_SIMPLE:-0.1} + - TEMPERATURE_COMPLEX=${TEMPERATURE_COMPLEX:-0.3} + - INFERENCE_MODEL_ENDPOINT_SIMPLE=${LLM_MODEL_ENDPOINT:-Qwen/Qwen3-4B-Instruct-2507} + - INFERENCE_MODEL_NAME_SIMPLE=${LLM_MODEL_NAME:-Qwen/Qwen3-4B-Instruct-2507} + - INFERENCE_MODEL_ENDPOINT_COMPLEX=${LLM_MODEL_ENDPOINT:-Qwen/Qwen3-4B-Instruct-2507} + - INFERENCE_MODEL_NAME_COMPLEX=${LLM_MODEL_NAME:-Qwen/Qwen3-4B-Instruct-2507} + - LLM_MODEL_ENDPOINT=${LLM_MODEL_ENDPOINT:-Qwen/Qwen3-4B-Instruct-2507} + - LLM_MODEL_NAME=${LLM_MODEL_NAME:-Qwen/Qwen3-4B-Instruct-2507} + # GenAI Gateway / APISIX Configuration + - GENAI_GATEWAY_URL=${GENAI_GATEWAY_URL} + - GENAI_API_KEY=${GENAI_API_KEY} + - LLM_API_ENDPOINT=${LLM_API_ENDPOINT:-} + - INFERENCE_BACKEND=${INFERENCE_BACKEND:-vllm} + - VERIFY_SSL=${VERIFY_SSL:-true} + extra_hosts: + - "${LOCAL_URL_ENDPOINT}:host-gateway" + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8003/health"] + interval: 30s + timeout: 10s + retries: 3 + + + ingestion: + build: + context: ./api/ingestion + dockerfile: Dockerfile + container_name: hybrid-search-ingestion + ports: + - "${INGESTION_PORT:-8004}:8004" + volumes: + - ./data/documents:/data/documents + - index_data:/data/indexes + - db_data:/data/db + environment: + - DEPLOYMENT_PHASE=${DEPLOYMENT_PHASE:-development} + - INGESTION_PORT=8004 + - SYSTEM_MODE=${SYSTEM_MODE:-document} + - EMBEDDING_SERVICE_URL=http://embedding:8001 + - DOCUMENT_STORAGE_PATH=/data/documents + - INDEX_STORAGE_PATH=/data/indexes + - METADATA_DB_PATH=/data/db/metadata.db + - CHUNK_SIZE=256 + - CHUNK_OVERLAP=25 + - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-32} + - MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-100} + - SUPPORTED_FORMATS=${SUPPORTED_FORMATS:-pdf,docx,xlsx,ppt,txt} + - 'EMBEDDING_FIELD_TEMPLATE=${EMBEDDING_FIELD_TEMPLATE:-"{name}. {description}. Category: {category}. Brand: {brand}"}' + - EMBEDDING_DIM=${EMBEDDING_DIMENSIONS:-768} + - MAX_PRODUCTS_PER_CATALOG=${MAX_PRODUCTS_PER_CATALOG:-50000} + depends_on: + - embedding + networks: + - hybrid-search-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8004/health"] + interval: 30s + timeout: 10s + retries: 3 + +networks: + hybrid-search-network: + driver: bridge + +volumes: + document_storage: + index_data: + db_data: + diff --git a/sample_solutions/HybridSearch/reranker-configuration.md b/sample_solutions/HybridSearch/reranker-configuration.md new file mode 100644 index 00000000..11b4ff86 --- /dev/null +++ b/sample_solutions/HybridSearch/reranker-configuration.md @@ -0,0 +1,257 @@ +# BAAI/bge-reranker-base: Post-Deployment Configuration Workflow + +> **Scope:** GenAI Gateway (LiteLLM) + Xeon deployments only. +> The following do **not** need this guide — the reranker works out of the box: +> - **Keycloak / APISIX** (Xeon or Gaudi): Set `RERANKER_API_ENDPOINT` in `.env` and `USE_RERANKING=true` +> - **GenAI Gateway + Gaudi**: Pre-validated model, no LiteLLM reconfiguration needed + +> **Environment:** GenAI Gateway (LiteLLM) | Xeon | vLLM backend +> **Deployment tool:** `inference-deploy.sh` (Enterprise Inference CLI) +> **Final service:** `bge-reranker-base-cpu-vllm-service.default` + +--- + +## Step 1: Deploy the Reranker Model + +Navigate to the inference directory and run the deployment script: + +```bash +cd /Enterprise-Inference +bash ./inference-deploy.sh +``` + +Follow this menu path: + +``` +> 3 (Update Deployed Inference Cluster) +> 2 (Manage LLM Models) +> 4 (Deploy Model from Hugging Face) + +Enter the HuggingFace Model ID: BAAI/bge-reranker-base +``` + +> The script will warn that the Kubernetes name will be normalized to `bge-reranker-base` (lowercase, hyphens only). This is expected — proceed. + +Wait for the pod to reach `Running` before proceeding: + +```bash +kubectl get pods -n default | grep bge-reranker +# bge-reranker-base-cpu-vllm-XXXXX 1/1 Running 0 2m +``` + +--- + +## Step 2: Set Up Authentication + +```bash +# TOKEN: the litellm_master_key from core/inventory/metadata/vault.yml +# (generated by generate-vault-secrets.sh — not a Keycloak token) +TOKEN="your-vault-token-here" + +# BASE_URL: GenAI Gateway base URL (no /v1 suffix) +# Used for both admin endpoints (/model/info, /model/update) and inference (/v1/rerank) +BASE_URL="https://api.example.com" +``` + +--- + +## Step 3: First Curl Test — Expect a Failure + +Before any configuration changes, test the rerank endpoint. **This will likely fail** because the model is registered with incorrect defaults by the deployment script. + +```bash +curl -k "${BASE_URL}/v1/rerank" \ + -X POST \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "model": "BAAI/bge-reranker-base", + "query": "What is the name of the dataset introduced in this paper?", + "documents": [ + "The dataset MMHS150K contains 150,000 multimodal tweets.", + "We use GloVe embeddings for text.", + "The paper introduces MMHS150K for hate speech detection." + ], + "top_n": 3, + "return_documents": false + }' +``` + +> **Note:** Use the `"documents"` field for GenAI Gateway (LiteLLM/Cohere provider). For Keycloak/APISIX deployments, use `"texts"` instead — the two backends expect different field names. + +**Expected failure response (OpenAI exception):** +```json +{ + "error": { + "message": "litellm.BadRequestError: ... This model does not support reranking ...", + "type": "invalid_request_error", + "code": 400 + } +} +``` + +> If you see this error, proceed to Step 4. The model is reachable but not yet configured correctly. +> If you get a 404 or connection refused, the pod is not running — go back to Step 1. + +--- + +## Step 4: Get the Model UUID + +The `model/update` curl command requires the internal LiteLLM model ID (a UUID), not the HuggingFace model name. + +**Option A — LiteLLM UI:** + +1. Open the LiteLLM UI → click **Models + Endpoints** in the left sidebar +2. In the model table, locate the row for `BAAI/bge-reranker-base` +3. The **Model ID** column shows a UUID like `77ce7b6e-3f75-4c66-9623-c735d0024e85` +4. Click the row to open the model detail page +5. Switch to the **Raw JSON** tab — the `id` field at the top of `model_info` is your UUID + +**Option B — Curl (works without UI access):** + +```bash +curl -k -s "${BASE_URL}/model/info" \ + -H "Authorization: Bearer ${TOKEN}" | \ + python3 -c " +import sys, json +data = json.load(sys.stdin) +print(f\"{'MODEL NAME':<40} {'API BASE':<50} {'UUID'}\") +print('-' * 110) +for m in data['data']: + name = m.get('model_name', 'N/A') + base = m.get('litellm_params', {}).get('api_base', 'N/A') + uuid = m.get('model_info', {}).get('id', 'N/A') + print(f'{name:<40} {base:<50} {uuid}') +" +``` + +--- + +## Step 5: Run the Model Update Curl Command + +Replace `` with the UUID you copied from Step 4. + +```bash +MODEL_UUID="77ce7b6e-3f75-4c66-9623-c735d0024e85" # ← paste your UUID here + +curl -k -X POST "${BASE_URL}/model/update" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "model_id": "'"${MODEL_UUID}"'", + "model_name": "BAAI/bge-reranker-base", + "litellm_params": { + "model": "cohere/BAAI/bge-reranker-base", + "custom_llm_provider": "cohere", + "api_base": "http://bge-reranker-base-cpu-vllm-service.default", + "input_cost_per_token": 0.001, + "output_cost_per_token": 0.002, + "use_in_pass_through": true, + "use_litellm_proxy": false, + "merge_reasoning_content_in_choices": false + }, + "model_info": { + "id": "'"${MODEL_UUID}"'", + "db_model": true, + "mode": "rerank", + "input_cost_per_token": 0.001, + "output_cost_per_token": 0.002, + "access_via_team_ids": [], + "direct_access": true, + "key": "cohere/BAAI/bge-reranker-base" + } + }' +``` + +**Fields being corrected by this command:** + +| Field | Before (broken) | After (correct) | +|---|---|---| +| `model` | `openai/BAAI/bge-reranker-base` | `cohere/BAAI/bge-reranker-base` | +| `custom_llm_provider` | `openai` | `cohere` | +| `mode` | *(missing)* | `rerank` | +| `use_in_pass_through` | `false` | `true` | +| `api_base` | *(may include `/v1` suffix)* | `...vllm-service.default` (no `/v1`) | + +> Expected response: HTTP 200 with the updated model JSON echoed back. + +--- + +## Step 6: Verify Changes in the LiteLLM UI + +After the update curl returns 200, go back to the LiteLLM UI and confirm every field updated correctly. + +**Navigation:** Models + Endpoints → click `BAAI/bge-reranker-base` → click **Edit Model** + +Verify the following in the Edit Model form: + +| Field | Required Value | +|---|---| +| **Public Model Name** | `BAAI/bge-reranker-base` | +| **LiteLLM Model Name** | `cohere/BAAI/bge-reranker-base` | +| **Custom LLM Provider** | `cohere` | +| **API Base** | `http://bge-reranker-base-cpu-vllm-service.default` (no `/v1`) | +| **Use In Pass Through** | `true` (toggled on) | +| **Mode** (in Model Info) | `rerank` | + +> The **LiteLLM Model Name** field is the most commonly missed — it must be `cohere/BAAI/bge-reranker-base`, not `BAAI/bge-reranker-base` or `openai/BAAI/bge-reranker-base`. If it did not update, edit it manually here and save. + +--- + +## Step 7: Re-run the Curl — Successful Result + +Run the same curl from Step 3: + +```bash +curl -k "${BASE_URL}/v1/rerank" \ + -X POST \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "model": "BAAI/bge-reranker-base", + "query": "What is the name of the dataset introduced in this paper?", + "documents": [ + "The dataset MMHS150K contains 150,000 multimodal tweets.", + "We use GloVe embeddings for text.", + "The paper introduces MMHS150K for hate speech detection." + ], + "top_n": 3, + "return_documents": false + }' +``` + +**Expected successful response (GenAI Gateway — LiteLLM/Cohere format):** + +```json +{ + "id": "rerank-...", + "results": [ + { "index": 2, "relevance_score": 0.9412, "document": { "text": "The paper introduces MMHS150K for hate speech detection." } }, + { "index": 0, "relevance_score": 0.8134, "document": { "text": "The dataset MMHS150K contains 150,000 multimodal tweets." } }, + { "index": 1, "relevance_score": 0.0231, "document": { "text": "We use GloVe embeddings for text." } } + ], + "model": "BAAI/bge-reranker-base", + "usage": { "total_tokens": 87 } +} +``` + +> **Note:** For Keycloak/APISIX deployments, the response is a flat array format instead: +> ```json +> [{"index": 2, "score": 0.9274}, {"index": 0, "score": 0.9241}, {"index": 1, "score": 0.0103}] +> ``` + +> Document at index 2 ranks first — it directly answers the query. Document at index 1 (GloVe embeddings) scores near-zero since it is irrelevant. This confirms the reranker is scoring correctly. + +--- + +## Troubleshooting + +| Symptom | Cause | Fix | +|---|---|---| +| Curl returns 404 | Pod not running | `kubectl get pods -n default \| grep bge-reranker` | +| Step 5 returns 404 on `/model/update` | Wrong UUID | Re-copy from Raw JSON tab in UI or re-run Option B curl | +| Step 5 returns 200 but UI shows no changes | `db_model: false` | Ensure `"db_model": true` in `model_info` | +| LiteLLM Model Name still shows `openai/` prefix | Update didn't persist | Edit manually in UI Edit Model form and save | +| Step 7 still throws OpenAI exception | `mode: rerank` not set | Check Raw JSON tab — re-run Step 5 | +| All relevance scores ~0.5 | vLLM loaded but not inferring | `kubectl logs -n default` | +| `use_in_pass_through` not toggling | UI bug | Set via curl in Step 5 and confirm in Raw JSON | diff --git a/sample_solutions/HybridSearch/scripts/download_amazon_dataset.py b/sample_solutions/HybridSearch/scripts/download_amazon_dataset.py new file mode 100644 index 00000000..79cc49dc --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/download_amazon_dataset.py @@ -0,0 +1,165 @@ +""" +Download and prepare Amazon Products dataset from HuggingFace +""" +import pandas as pd +import re +import random +from datasets import load_dataset + +def clean_price(price_str): + """Extract numeric price from string""" + if not price_str or pd.isna(price_str): + return None + + # Remove currency symbols and extract first number + match = re.search(r'[\d,]+\.?\d*', str(price_str)) + if match: + price = match.group().replace(',', '') + try: + return float(price) + except: + return None + return None + +def extract_brand(text): + """Try to extract brand from text""" + if not text or pd.isna(text): + return "Generic" + + # Common brand patterns + text_str = str(text) + words = text_str.split() + if len(words) > 0: + # Take first word as brand (often the brand name) + brand = words[0].strip('.,;:-') + if len(brand) > 2: + return brand + return "Generic" + +def generate_rating(): + """Generate realistic ratings (skewed toward 4-5 stars)""" + weights = [0.05, 0.1, 0.15, 0.35, 0.35] # More 4s and 5s + return round(random.choices([1.0, 2.0, 3.0, 4.0, 5.0], weights=weights)[0] + random.uniform(0, 0.9), 1) + +def generate_review_count(rating): + """Generate review count based on rating""" + # Higher rated products tend to have more reviews + base = random.randint(10, 1000) + multiplier = rating / 5.0 + return int(base * multiplier) + +def clean_text(text): + """Clean text fields""" + if not text or pd.isna(text): + return "" + text_str = str(text).strip() + # Remove excessive whitespace + text_str = ' '.join(text_str.split()) + return text_str[:500] # Limit length + +def simplify_category(category): + """Simplify category to main category""" + if not category or pd.isna(category): + return "General" + + cat_str = str(category) + # Take first category if multiple + if '|' in cat_str: + return cat_str.split('|')[0].strip() + if '>' in cat_str: + return cat_str.split('>')[0].strip() + return cat_str.strip()[:50] + +def download_and_prepare(output_file='../data/test_datasets/amazon_products.csv', max_products=200): + """Download and prepare Amazon dataset""" + + print("📥 Downloading Amazon Products dataset from HuggingFace...") + print("This may take a few minutes...") + + try: + # Load dataset (train split has 24k products) + dataset = load_dataset("ckandemir/amazon-products", split="train", revision="main") # nosec B615 + + print(f"✅ Downloaded {len(dataset)} products") + print("🔄 Converting to DataFrame...") + + # Convert to pandas DataFrame + df = pd.DataFrame(dataset) + + print(f"📊 Columns: {df.columns.tolist()}") + + # Map columns to our schema + products = [] + + print(f"🔄 Processing products (taking first {max_products})...") + + for idx, row in df.head(max_products).iterrows(): + # Extract and clean data + name = clean_text(row.get('Product Name', '')) + if not name or len(name) < 3: + continue + + description = clean_text(row.get('Description', '')) + category = simplify_category(row.get('Category', '')) + price = clean_price(row.get('Selling Price', '')) + + # Skip if no valid price + if not price or price <= 0 or price > 10000: + continue + + image_url = row.get('Image', '') + if not image_url or pd.isna(image_url): + # Use placeholder + image_url = f"https://via.placeholder.com/400x400/3b82f6/ffffff?text=Product" + else: + image_url = str(image_url).strip() + + # Extract or generate additional fields + brand = extract_brand(name) + rating = generate_rating() + review_count = generate_review_count(rating) + + product = { + 'id': f'amz_{idx:05d}', + 'name': name, + 'description': description if description else name, + 'category': category, + 'price': round(price, 2), + 'rating': rating, + 'review_count': review_count, + 'image_url': image_url, + 'brand': brand + } + + products.append(product) + + if (idx + 1) % 50 == 0: + print(f" Processed {idx + 1}/{max_products}...") + + # Create DataFrame and save + products_df = pd.DataFrame(products) + products_df.to_csv(output_file, index=False) + + # Print statistics + print(f"\n✅ Successfully prepared {len(products_df)} products!") + print(f"📁 Saved to: {output_file}") + print(f"\n📊 Statistics:") + print(f" Categories: {products_df['category'].nunique()}") + print(f" Price range: ${products_df['price'].min():.2f} - ${products_df['price'].max():.2f}") + print(f" Avg price: ${products_df['price'].mean():.2f}") + print(f" Rating range: {products_df['rating'].min():.1f} - {products_df['rating'].max():.1f}") + print(f"\n🏷️ Top 10 Categories:") + print(products_df['category'].value_counts().head(10)) + + return products_df + + except Exception as e: + print(f"❌ Error: {e}") + print(f"Make sure you have the required packages:") + print(f" pip install datasets pandas") + raise + +if __name__ == "__main__": + # Download 200 products for testing (you can increase this) + download_and_prepare(max_products=200) + diff --git a/sample_solutions/HybridSearch/scripts/force_reload.py b/sample_solutions/HybridSearch/scripts/force_reload.py new file mode 100644 index 00000000..f1e161f2 --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/force_reload.py @@ -0,0 +1,44 @@ +import httpx +import asyncio +import json + +RETRIEVAL_URL = "http://localhost:8002" +GATEWAY_URL = "http://localhost:8000" + +async def verify_reload(): + async with httpx.AsyncClient() as client: + print(f"1. Triggering reload at {RETRIEVAL_URL}/api/v1/reload...") + try: + response = await client.post(f"{RETRIEVAL_URL}/api/v1/reload") + if response.status_code == 200: + print("Reload successful!") + print(json.dumps(response.json(), indent=2)) + else: + print(f"Reload failed: {response.status_code} - {response.text}") + return + except Exception as e: + print(f"Failed to connect to retrieval service: {e}") + return + + print("\n2. Testing Product Search...") + try: + response = await client.post( + f"{GATEWAY_URL}/api/v1/search", + json={"query": "product", "limit": 5} + ) + if response.status_code == 200: + data = response.json() + results = data.get("results", []) + print(f"Found {len(results)} products.") + if results: + print("First product:", results[0]['name']) + print("SUCCESS: Products are searchable!") + else: + print("FAILURE: No products found after reload.") + else: + print(f"Search failed: {response.status_code} - {response.text}") + except Exception as e: + print(f"Failed to connect to gateway: {e}") + +if __name__ == "__main__": + asyncio.run(verify_reload()) diff --git a/sample_solutions/HybridSearch/scripts/generate_ecommerce_dataset.py b/sample_solutions/HybridSearch/scripts/generate_ecommerce_dataset.py new file mode 100644 index 00000000..d0737637 --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/generate_ecommerce_dataset.py @@ -0,0 +1,160 @@ +""" +Generate a realistic e-commerce product dataset with images +""" +import csv +import random + +# Product templates organized by category +PRODUCTS = { + "Electronics": [ + {"name": "Wireless Bluetooth Headphones", "desc": "Premium noise-canceling headphones with 30-hour battery", "price_range": (49.99, 199.99), "brand": "SoundTech"}, + {"name": "Smartphone 5G", "desc": "Latest 5G smartphone with 6.5\" display and triple camera", "price_range": (399.99, 1299.99), "brand": "TechPro"}, + {"name": "Laptop 15.6 inch", "desc": "Powerful laptop with Intel i7, 16GB RAM, 512GB SSD", "price_range": (699.99, 1899.99), "brand": "CompuMax"}, + {"name": "Wireless Mouse", "desc": "Ergonomic wireless mouse with adjustable DPI", "price_range": (19.99, 79.99), "brand": "TechMouse"}, + {"name": "Mechanical Keyboard", "desc": "RGB backlit mechanical gaming keyboard", "price_range": (59.99, 199.99), "brand": "KeyMaster"}, + {"name": "USB-C Hub", "desc": "7-in-1 USB-C hub with HDMI, USB 3.0, SD card reader", "price_range": (29.99, 89.99), "brand": "ConnectPro"}, + {"name": "Wireless Earbuds", "desc": "True wireless earbuds with active noise cancellation", "price_range": (49.99, 299.99), "brand": "AudioMax"}, + {"name": "Smart Watch", "desc": "Fitness tracker smart watch with heart rate monitor", "price_range": (99.99, 499.99), "brand": "FitWatch"}, + {"name": "Portable Charger", "desc": "20000mAh power bank with fast charging", "price_range": (24.99, 79.99), "brand": "PowerBoost"}, + {"name": "Webcam HD", "desc": "1080p HD webcam with built-in microphone", "price_range": (39.99, 149.99), "brand": "CamPro"}, + {"name": "External SSD", "desc": "1TB portable external SSD with USB 3.2", "price_range": (89.99, 299.99), "brand": "StorageMax"}, + {"name": "Monitor 27 inch", "desc": "4K UHD monitor with HDR and 144Hz refresh rate", "price_range": (299.99, 799.99), "brand": "ViewPro"}, + {"name": "Laptop Backpack", "desc": "Water-resistant laptop backpack with USB charging port", "price_range": (29.99, 89.99), "brand": "TechPack"}, + {"name": "Wireless Charger", "desc": "Fast wireless charging pad for smartphones", "price_range": (19.99, 59.99), "brand": "ChargeFast"}, + {"name": "Bluetooth Speaker", "desc": "Portable waterproof Bluetooth speaker", "price_range": (39.99, 199.99), "brand": "SoundWave"}, + ], + "Home & Kitchen": [ + {"name": "Coffee Maker", "desc": "Programmable drip coffee maker with thermal carafe", "price_range": (39.99, 199.99), "brand": "BrewMaster"}, + {"name": "Blender", "desc": "High-speed blender with multiple settings", "price_range": (49.99, 299.99), "brand": "BlendPro"}, + {"name": "Air Fryer", "desc": "6-quart digital air fryer with preset functions", "price_range": (59.99, 199.99), "brand": "CrispyChef"}, + {"name": "Knife Set", "desc": "Professional 15-piece stainless steel knife set", "price_range": (79.99, 299.99), "brand": "ChefPro"}, + {"name": "Vacuum Cleaner", "desc": "Cordless stick vacuum with HEPA filter", "price_range": (149.99, 599.99), "brand": "CleanMaster"}, + {"name": "Water Bottle", "desc": "Insulated stainless steel water bottle", "price_range": (19.99, 49.99), "brand": "HydroFlask"}, + {"name": "Cookware Set", "desc": "Non-stick 10-piece cookware set", "price_range": (99.99, 399.99), "brand": "CookPro"}, + {"name": "Food Processor", "desc": "12-cup food processor with multiple blades", "price_range": (79.99, 299.99), "brand": "ChopMaster"}, + {"name": "Toaster Oven", "desc": "6-slice convection toaster oven", "price_range": (49.99, 199.99), "brand": "ToastPro"}, + {"name": "Electric Kettle", "desc": "1.7L electric kettle with temperature control", "price_range": (29.99, 99.99), "brand": "BoilFast"}, + {"name": "Mixer Stand", "desc": "6-speed stand mixer with stainless steel bowl", "price_range": (149.99, 499.99), "brand": "MixMaster"}, + {"name": "Cutting Board Set", "desc": "Bamboo cutting board set of 3", "price_range": (24.99, 79.99), "brand": "ChopBoard"}, + {"name": "Storage Containers", "desc": "Glass food storage containers set of 10", "price_range": (29.99, 89.99), "brand": "StoreFresh"}, + {"name": "Dish Rack", "desc": "Stainless steel dish drying rack", "price_range": (24.99, 79.99), "brand": "DryWell"}, + ], + "Sports & Outdoors": [ + {"name": "Yoga Mat", "desc": "Extra thick 6mm yoga mat with carrying strap", "price_range": (19.99, 79.99), "brand": "FitLife"}, + {"name": "Resistance Bands", "desc": "5-piece resistance band set for home workouts", "price_range": (14.99, 49.99), "brand": "FitGear"}, + {"name": "Dumbbells Set", "desc": "Adjustable dumbbell set 5-50 lbs", "price_range": (99.99, 399.99), "brand": "IronFit"}, + {"name": "Jump Rope", "desc": "Speed jump rope with adjustable length", "price_range": (9.99, 29.99), "brand": "FitJump"}, + {"name": "Camping Tent", "desc": "4-person waterproof camping tent", "price_range": (79.99, 299.99), "brand": "OutdoorPro"}, + {"name": "Sleeping Bag", "desc": "Lightweight sleeping bag for camping", "price_range": (39.99, 149.99), "brand": "SleepWell"}, + {"name": "Hiking Backpack", "desc": "50L hiking backpack with rain cover", "price_range": (59.99, 199.99), "brand": "TrailMaster"}, + {"name": "Water Filter", "desc": "Portable water filter for camping", "price_range": (24.99, 79.99), "brand": "PureWater"}, + {"name": "Bike Helmet", "desc": "Adjustable bike helmet with LED light", "price_range": (29.99, 99.99), "brand": "SafeRide"}, + {"name": "Tennis Racket", "desc": "Professional tennis racket with case", "price_range": (49.99, 249.99), "brand": "GamePro"}, + {"name": "Soccer Ball", "desc": "Official size 5 soccer ball", "price_range": (19.99, 59.99), "brand": "KickMaster"}, + {"name": "Swim Goggles", "desc": "Anti-fog swim goggles with UV protection", "price_range": (14.99, 49.99), "brand": "SwimPro"}, + ], + "Clothing & Shoes": [ + {"name": "Running Shoes", "desc": "Lightweight breathable running shoes", "price_range": (59.99, 179.99), "brand": "RunFast"}, + {"name": "Athletic Shorts", "desc": "Quick-dry athletic shorts with pockets", "price_range": (19.99, 49.99), "brand": "FitWear"}, + {"name": "T-Shirt Pack", "desc": "Pack of 3 performance t-shirts", "price_range": (24.99, 79.99), "brand": "ComfortFit"}, + {"name": "Hoodie", "desc": "Fleece pullover hoodie with pockets", "price_range": (29.99, 89.99), "brand": "CozyWear"}, + {"name": "Jeans", "desc": "Slim fit stretch denim jeans", "price_range": (39.99, 129.99), "brand": "DenimPro"}, + {"name": "Sneakers", "desc": "Casual sneakers with memory foam insole", "price_range": (49.99, 149.99), "brand": "StepComfort"}, + {"name": "Winter Jacket", "desc": "Waterproof winter jacket with hood", "price_range": (79.99, 299.99), "brand": "WarmGuard"}, + {"name": "Baseball Cap", "desc": "Adjustable baseball cap with logo", "price_range": (14.99, 39.99), "brand": "CapPro"}, + {"name": "Socks Pack", "desc": "Pack of 6 athletic socks", "price_range": (14.99, 34.99), "brand": "ComfortSocks"}, + {"name": "Backpack", "desc": "School/work backpack with laptop compartment", "price_range": (29.99, 99.99), "brand": "PackPro"}, + ], + "Books & Media": [ + {"name": "Fiction Novel", "desc": "Bestselling fiction novel paperback", "price_range": (9.99, 29.99), "brand": "ReadWell"}, + {"name": "Self-Help Book", "desc": "Personal development and productivity book", "price_range": (12.99, 34.99), "brand": "GrowMind"}, + {"name": "Cookbook", "desc": "Healthy cooking recipes cookbook", "price_range": (14.99, 39.99), "brand": "ChefBook"}, + {"name": "Journal", "desc": "Leather-bound journal with lined pages", "price_range": (12.99, 39.99), "brand": "WriteWell"}, + {"name": "Coloring Book", "desc": "Adult coloring book for relaxation", "price_range": (9.99, 24.99), "brand": "ColorJoy"}, + {"name": "Board Game", "desc": "Family board game for 2-6 players", "price_range": (19.99, 79.99), "brand": "GameNight"}, + {"name": "Puzzle 1000pc", "desc": "1000-piece jigsaw puzzle", "price_range": (14.99, 39.99), "brand": "PuzzleMaster"}, + ], + "Beauty & Personal Care": [ + {"name": "Electric Toothbrush", "desc": "Rechargeable electric toothbrush with timer", "price_range": (29.99, 149.99), "brand": "SmilePro"}, + {"name": "Hair Dryer", "desc": "Ionic hair dryer with diffuser", "price_range": (39.99, 149.99), "brand": "StylePro"}, + {"name": "Moisturizer", "desc": "Daily facial moisturizer with SPF", "price_range": (14.99, 49.99), "brand": "GlowCare"}, + {"name": "Shampoo Set", "desc": "Shampoo and conditioner set", "price_range": (19.99, 59.99), "brand": "HairCare"}, + {"name": "Perfume", "desc": "Luxury eau de parfum spray", "price_range": (39.99, 199.99), "brand": "Essence"}, + {"name": "Makeup Brush Set", "desc": "Professional makeup brush set of 12", "price_range": (24.99, 89.99), "brand": "BeautyPro"}, + {"name": "Face Mask Set", "desc": "Variety pack of sheet face masks", "price_range": (14.99, 39.99), "brand": "SkinCare"}, + {"name": "Electric Shaver", "desc": "Cordless electric shaver for men", "price_range": (49.99, 199.99), "brand": "ShaveMaster"}, + ], + "Toys & Games": [ + {"name": "Building Blocks", "desc": "Creative building blocks set 500 pieces", "price_range": (29.99, 99.99), "brand": "BuildIt"}, + {"name": "RC Car", "desc": "Remote control racing car with rechargeable battery", "price_range": (39.99, 149.99), "brand": "SpeedRacer"}, + {"name": "Doll House", "desc": "Wooden doll house with furniture", "price_range": (49.99, 199.99), "brand": "PlayHome"}, + {"name": "Action Figure", "desc": "Collectible action figure with accessories", "price_range": (14.99, 49.99), "brand": "HeroToys"}, + {"name": "Art Supplies", "desc": "Complete art supplies set for kids", "price_range": (24.99, 79.99), "brand": "ArtKids"}, + {"name": "Science Kit", "desc": "Educational science experiment kit", "price_range": (29.99, 89.99), "brand": "LearnScience"}, + ] +} + +# Placeholder image service (using placeholder.com for realistic URLs) +def get_image_url(product_index, category): + """Generate placeholder image URL""" + colors = ["3498db", "e74c3c", "2ecc71", "f39c12", "9b59b6", "1abc9c"] + color = colors[product_index % len(colors)] + cat_short = category.replace(" & ", "-").replace(" ", "-").lower() + return f"https://via.placeholder.com/400x400/{color}/ffffff?text={cat_short}" + +def generate_dataset(output_file, num_products_per_category=10): + """Generate e-commerce dataset""" + products = [] + product_id = 1 + + for category, product_templates in PRODUCTS.items(): + for i in range(num_products_per_category): + # Select a product template and create variation + template = product_templates[i % len(product_templates)] + + # Add variation to name if repeating + variation_suffix = "" + if i >= len(product_templates): + variations = ["Pro", "Plus", "Max", "Ultra", "Premium", "Deluxe", "Elite"] + variation_suffix = f" {variations[i % len(variations)]}" + + # Generate price within range + price = round(random.uniform(*template["price_range"]), 2) + + # Generate rating (skewed toward 4-5 stars) + rating = round(random.uniform(3.5, 5.0), 1) + + # Generate review count (higher rated products have more reviews) + review_count = int(random.uniform(50, 2000) * (rating / 5.0)) + + product = { + "id": f"prod_{product_id:03d}", + "name": template["name"] + variation_suffix, + "description": template["desc"], + "category": category, + "price": price, + "rating": rating, + "review_count": review_count, + "image_url": get_image_url(product_id, category), + "brand": template["brand"] + } + + products.append(product) + product_id += 1 + + # Write to CSV + with open(output_file, 'w', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=['id', 'name', 'description', 'category', 'price', 'rating', 'review_count', 'image_url', 'brand']) + writer.writeheader() + writer.writerows(products) + + print(f"✅ Generated {len(products)} products across {len(PRODUCTS)} categories") + print(f"📁 Saved to: {output_file}") + print(f"💰 Price range: ${min(p['price'] for p in products):.2f} - ${max(p['price'] for p in products):.2f}") + print(f"⭐ Rating range: {min(p['rating'] for p in products):.1f} - {max(p['rating'] for p in products):.1f}") + print(f"📦 Categories: {', '.join(PRODUCTS.keys())}") + +if __name__ == "__main__": + output_file = "../data/test_datasets/ecommerce_products.csv" + generate_dataset(output_file, num_products_per_category=10) + diff --git a/sample_solutions/HybridSearch/scripts/prepare_product_dataset.py b/sample_solutions/HybridSearch/scripts/prepare_product_dataset.py new file mode 100644 index 00000000..5c5a97a5 --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/prepare_product_dataset.py @@ -0,0 +1,186 @@ +""" +Prepare Product Dataset +Download and prepare Amazon products dataset from HuggingFace +""" + +import logging +import pandas as pd +import json +from pathlib import Path +from typing import List, Dict +import sys + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +try: + from datasets import load_dataset +except ImportError: + logger.error("datasets library not found. Install with: pip install datasets") + sys.exit(1) + + +def download_dataset(): + """ + Download Amazon products dataset from HuggingFace + + Returns: + Dataset object + """ + logger.info("Downloading Amazon products dataset from HuggingFace...") + try: + dataset = load_dataset("ckandemir/amazon-products", split="train", revision="main") # nosec B615 + logger.info(f"Downloaded dataset with {len(dataset)} products") + return dataset + except Exception as e: + logger.error(f"Error downloading dataset: {e}") + raise + + +def clean_product(product: Dict) -> Dict: + """ + Clean and normalize a product + + Args: + product: Raw product dictionary + + Returns: + Cleaned product dictionary + """ + cleaned = {} + + # Map fields (adjust based on actual dataset structure) + cleaned['name'] = product.get('title') or product.get('name') or product.get('product_name', '') + cleaned['description'] = product.get('description') or product.get('desc') or product.get('details', '') + cleaned['category'] = product.get('category') or product.get('categories', '') + cleaned['price'] = product.get('price') or product.get('cost') or product.get('list_price') + cleaned['rating'] = product.get('rating') or product.get('stars') or product.get('avg_rating') + cleaned['review_count'] = product.get('review_count') or product.get('reviews') or product.get('num_reviews') + cleaned['image_url'] = product.get('image_url') or product.get('image') or product.get('img') + cleaned['brand'] = product.get('brand') or product.get('manufacturer') + + # Generate ID if missing + if not product.get('id') and not product.get('product_id'): + import uuid + cleaned['id'] = f"prod_{uuid.uuid4().hex[:12]}" + else: + cleaned['id'] = product.get('id') or product.get('product_id') + + # Clean price + if cleaned['price']: + try: + if isinstance(cleaned['price'], str): + # Remove currency symbols + price_str = cleaned['price'].replace('$', '').replace(',', '').strip() + cleaned['price'] = float(price_str) if price_str else None + else: + cleaned['price'] = float(cleaned['price']) + except (ValueError, TypeError): + cleaned['price'] = None + + # Clean rating (normalize to 0-5) + if cleaned['rating']: + try: + rating = float(cleaned['rating']) + if rating > 5: + rating = rating / 2.0 # Assume out of 10 + cleaned['rating'] = rating if 0 <= rating <= 5 else None + except (ValueError, TypeError): + cleaned['rating'] = None + + # Clean review count + if cleaned['review_count']: + try: + cleaned['review_count'] = int(cleaned['review_count']) + except (ValueError, TypeError): + cleaned['review_count'] = None + + # Ensure name is not empty + if not cleaned['name']: + cleaned['name'] = f"Product {cleaned['id']}" + + # Ensure description is not empty (use name as fallback) + if not cleaned['description']: + cleaned['description'] = cleaned['name'] + + return cleaned + + +def create_test_subsets(dataset, output_dir: Path): + """ + Create test subsets from dataset + + Args: + dataset: Dataset object + output_dir: Output directory + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Convert to list of dictionaries + logger.info("Converting dataset to list...") + products = [] + for item in dataset: + cleaned = clean_product(item) + # Only include products with name and description + if cleaned['name'] and cleaned['description']: + products.append(cleaned) + + logger.info(f"Cleaned {len(products)} valid products") + + # Create subsets + subsets = { + 'test_100.csv': 100, + 'test_1000.csv': 1000, + 'test_10000.csv': 10000 + } + + for filename, count in subsets.items(): + if len(products) >= count: + subset = products[:count] + output_path = output_dir / filename + + # Save as CSV + df = pd.DataFrame(subset) + df.to_csv(output_path, index=False) + logger.info(f"Created {filename} with {len(subset)} products") + else: + logger.warning(f"Not enough products for {filename} (have {len(products)}, need {count})") + + # Also save full dataset if requested + if len(products) > 0: + full_path = output_dir / "full_dataset.csv" + df = pd.DataFrame(products) + df.to_csv(full_path, index=False) + logger.info(f"Created full_dataset.csv with {len(products)} products") + + +def main(): + """Main function""" + # Set output directory + script_dir = Path(__file__).parent + project_root = script_dir.parent + output_dir = project_root / "data" / "test_datasets" + + logger.info(f"Output directory: {output_dir}") + + try: + # Download dataset + dataset = download_dataset() + + # Create test subsets + create_test_subsets(dataset, output_dir) + + logger.info("Dataset preparation complete!") + logger.info(f"Test datasets saved to: {output_dir}") + + except Exception as e: + logger.error(f"Error preparing dataset: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() + diff --git a/sample_solutions/HybridSearch/scripts/test_product_upload.py b/sample_solutions/HybridSearch/scripts/test_product_upload.py new file mode 100644 index 00000000..88be9bcb --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/test_product_upload.py @@ -0,0 +1,70 @@ +import asyncio +import httpx +import json +import os + +# Configuration +API_URL = "http://localhost:8004" # Ingestion service + +async def verify_upload_fix(): + async with httpx.AsyncClient(timeout=60.0) as client: + print("1. Clearing product catalog...") + response = await client.delete(f"{API_URL}/api/v1/products/clear") + if response.status_code != 200: + print(f"Failed to clear products: {response.text}") + return + print("Products cleared.") + + print("\n2. Uploading Product Catalog (should auto-process)...") + # Create a dummy CSV file with standard headers + with open("test_products.csv", "w") as f: + f.write("id,name,description,category,price,rating,review_count,image_url,brand\n") + f.write("1,Test Product,A test product.,Test Category,10.00,4.5,10,http://example.com/image.jpg,Test Brand\n") + + files = {'file': ('test_products.csv', open('test_products.csv', 'rb'), 'text/csv')} + response = await client.post(f"{API_URL}/api/v1/products/upload", files=files) + + if response.status_code != 202: + print(f"Failed to upload products: {response.text}") + return + + data = response.json() + job_id = data['job_id'] + status = data['status'] + requires_confirmation = data['requires_confirmation'] + + print(f"Upload response status: {status}") + print(f"Requires confirmation: {requires_confirmation}") + + if requires_confirmation: + print("FAILURE: Upload still requires confirmation for standard headers!") + return + + if status != "processing": + print(f"FAILURE: Status should be 'processing', got '{status}'") + return + + print(f"Job started: {job_id}") + + # Poll for completion + print("Waiting for processing...") + for _ in range(10): + await asyncio.sleep(1) + response = await client.get(f"{API_URL}/api/v1/products/status/{job_id}") + job_status = response.json() + print(f"Job status: {job_status['status']} ({job_status['products_processed']}/{job_status['products_total']})") + + if job_status['status'] == 'complete': + print("SUCCESS: Product upload auto-processed successfully!") + break + if job_status['status'] == 'error': + print(f"FAILURE: Job failed with error: {job_status['errors']}") + break + else: + print("FAILURE: Timeout waiting for processing") + + # Cleanup + os.remove("test_products.csv") + +if __name__ == "__main__": + asyncio.run(verify_upload_fix()) diff --git a/sample_solutions/HybridSearch/scripts/verify_separation.py b/sample_solutions/HybridSearch/scripts/verify_separation.py new file mode 100644 index 00000000..c0ba6068 --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/verify_separation.py @@ -0,0 +1,125 @@ +import asyncio +import httpx +import json +import os + +# Configuration +API_URL = "http://localhost:8004" # Ingestion service +RETRIEVAL_URL = "http://localhost:8002" # Retrieval service + +async def verify_separation(): + async with httpx.AsyncClient(timeout=60.0) as client: + print("1. Clearing all indexes...") + response = await client.delete(f"{API_URL}/api/v1/documents/clear-all") + if response.status_code != 200: + print(f"Failed to clear indexes: {response.text}") + return + print("Indexes cleared.") + + print("\n2. Uploading Document (California Drivers License)...") + # Create a dummy PDF file + with open("drivers_license.txt", "w") as f: + f.write("The California Drivers License Handbook covers rules of the road, traffic signs, and safe driving practices.") + + files = {'file': ('drivers_license.txt', open('drivers_license.txt', 'rb'), 'text/plain')} + response = await client.post(f"{API_URL}/api/v1/documents/upload", files=files) + if response.status_code != 202: + print(f"Failed to upload document: {response.text}") + return + doc_id = response.json()['document_id'] + print(f"Document uploaded: {doc_id}") + + # Wait for processing + print("Waiting for document processing...") + await asyncio.sleep(5) + + print("\n3. Uploading Product Catalog (Shoes)...") + # Create a dummy CSV file + with open("products.csv", "w") as f: + f.write("id,name,description,category,price\n") + f.write("1,Running Shoes,High performance running shoes for athletes.,Footwear,99.99\n") + f.write("2,Hiking Boots,Durable boots for rough terrain.,Footwear,129.99\n") + + files = {'file': ('products.csv', open('products.csv', 'rb'), 'text/csv')} + response = await client.post(f"{API_URL}/api/v1/products/upload", files=files) + if response.status_code != 202: + print(f"Failed to upload products: {response.text}") + return + job_data = response.json() + job_id = job_data['job_id'] + print(f"Product upload job started: {job_id}") + + # Confirm mapping + mapping = { + "name": "Product Catalog", + "id_field": "id", + "name_field": "name", + "description_field": "description", + "category_field": "category", + "price_field": "price" + } + + response = await client.post( + f"{API_URL}/api/v1/products/confirm", + data={ + "job_id": job_id, + "field_mapping": json.dumps(mapping) + } + ) + if response.status_code != 202: + print(f"Failed to confirm mapping: {response.text}") + return + print("Product mapping confirmed.") + + # Wait for processing + print("Waiting for product processing...") + await asyncio.sleep(5) + + # Reload indexes + print("\n4. Reloading indexes...") + await client.post(f"{RETRIEVAL_URL}/api/v1/reload") + + print("\n5. Verifying Document Search (Query: 'shoes')...") + # Should NOT find products + response = await client.post( + f"{RETRIEVAL_URL}/api/v1/retrieve/hybrid", + json={ + "query": "shoes", + "top_k_candidates": 10, + "top_k_fusion": 5, + "top_k_final": 5 + } + ) + results = response.json()['results'] + print(f"Found {len(results)} results.") + for res in results: + print(f" - {res.get('text', '')[:50]}... (Source: {res.get('metadata', {}).get('filename', 'Unknown')})") + if "Running Shoes" in res.get('text', ''): + print("FAILURE: Product found in document search!") + return + + print("\n6. Verifying Product Search (Query: 'license')...") + # Should NOT find documents + response = await client.post( + f"{RETRIEVAL_URL}/api/v1/search/products", + json={ + "query_text": "license", + "top_k": 5 + } + ) + results = response.json()['results'] + print(f"Found {len(results)} results.") + for res in results: + print(f" - {res.get('name', '')}: {res.get('description', '')[:50]}...") + if "California Drivers License" in res.get('description', ''): + print("FAILURE: Document found in product search!") + return + + print("\nSUCCESS: Contexts are properly separated!") + + # Cleanup + os.remove("drivers_license.txt") + os.remove("products.csv") + +if __name__ == "__main__": + asyncio.run(verify_separation()) diff --git a/sample_solutions/HybridSearch/scripts/verify_setup.sh b/sample_solutions/HybridSearch/scripts/verify_setup.sh new file mode 100755 index 00000000..d6e76967 --- /dev/null +++ b/sample_solutions/HybridSearch/scripts/verify_setup.sh @@ -0,0 +1,140 @@ +#!/bin/bash + +# Hybrid Search RAG - Setup Verification Script +# This script verifies that the project structure is complete + +set -e + +echo "======================================" +echo "Hybrid Search RAG - Setup Verification" +echo "======================================" +echo "" + +# Color codes +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Function to check if directory exists +check_dir() { + if [ -d "$1" ]; then + echo -e "${GREEN}✓${NC} Directory exists: $1" + return 0 + else + echo -e "${RED}✗${NC} Directory missing: $1" + return 1 + fi +} + +# Function to check if file exists +check_file() { + if [ -f "$1" ]; then + echo -e "${GREEN}✓${NC} File exists: $1" + return 0 + else + echo -e "${RED}✗${NC} File missing: $1" + return 1 + fi +} + +echo "Checking project structure..." +echo "" + +# Check main directories +echo "Main Directories:" +check_dir "api" +check_dir "ui" +check_dir "data" +check_dir "tests" +check_dir "scripts" +echo "" + +# Check API services +echo "API Services:" +check_dir "api/gateway" +check_dir "api/embedding" +check_dir "api/retrieval" +check_dir "api/llm" +check_dir "api/ingestion" +echo "" + +# Check configuration files +echo "Configuration Files:" +check_file "env.example" +check_file "docker-compose.yml" +check_file ".gitignore" +check_file "README.md" +check_file "IMPLEMENTATION_PLAN.md" +check_file "SETUP_SUMMARY.md" +check_file "architecture.md" +echo "" + +# Check requirements files +echo "Requirements Files:" +check_file "api/gateway/requirements.txt" +check_file "api/embedding/requirements.txt" +check_file "api/retrieval/requirements.txt" +check_file "api/llm/requirements.txt" +check_file "api/ingestion/requirements.txt" +check_file "ui/requirements.txt" +echo "" + +# Check subdirectories +echo "Service Subdirectories:" +check_dir "api/gateway/routers" +check_dir "api/gateway/services" +check_dir "api/retrieval/services" +check_dir "api/llm/models" +check_dir "api/llm/prompts" +check_dir "api/ingestion/services" +check_dir "ui/pages" +check_dir "ui/components" +echo "" + +# Check data directories +echo "Data Directories:" +check_dir "data/documents" +check_dir "data/indexes" +check_file "data/documents/.gitkeep" +check_file "data/indexes/.gitkeep" +echo "" + +# Check for .env file +echo "Environment Configuration:" +if [ -f ".env" ]; then + echo -e "${GREEN}✓${NC} .env file exists" + + # Check if OpenAI API key is set + if grep -q "OPENAI_API_KEY=sk-" .env 2>/dev/null; then + echo -e "${GREEN}✓${NC} OpenAI API key is configured" + elif grep -q "OPENAI_API_KEY=your-openai-api-key-here" .env 2>/dev/null; then + echo -e "${YELLOW}!${NC} OpenAI API key needs to be updated" + else + echo -e "${YELLOW}!${NC} OpenAI API key not found in .env" + fi +else + echo -e "${YELLOW}!${NC} .env file not found (copy from env.example)" +fi +echo "" + +# Summary +echo "======================================" +echo "Verification Complete!" +echo "======================================" +echo "" +echo "Next Steps:" +echo "1. Copy env.example to .env: cp env.example .env" +echo "2. Add your OpenAI API key to .env" +echo "3. Review SETUP_SUMMARY.md for implementation roadmap" +echo "4. Start implementing services in this order:" +echo " a. Embedding Service" +echo " b. LLM Service" +echo " c. Document Ingestion Service" +echo " d. Retrieval Service" +echo " e. Gateway Service" +echo " f. UI Service" +echo "" +echo "See SETUP_SUMMARY.md for detailed implementation guide" +echo "" + diff --git a/sample_solutions/HybridSearch/ui/.streamlit/config.toml b/sample_solutions/HybridSearch/ui/.streamlit/config.toml new file mode 100644 index 00000000..42f97628 --- /dev/null +++ b/sample_solutions/HybridSearch/ui/.streamlit/config.toml @@ -0,0 +1,15 @@ +[server] +port = 8501 +headless = true +enableCORS = false +enableXsrfProtection = true + +[browser] +gatherUsageStats = false + +[theme] +primaryColor = "#667eea" +backgroundColor = "#ffffff" +secondaryBackgroundColor = "#f0f2f6" +textColor = "#262730" +font = "sans serif" diff --git a/sample_solutions/HybridSearch/ui/Dockerfile b/sample_solutions/HybridSearch/ui/Dockerfile new file mode 100644 index 00000000..bdcabb20 --- /dev/null +++ b/sample_solutions/HybridSearch/ui/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.9-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application and create non-root user +COPY config.py . +COPY app.py . +RUN useradd -m -u 1000 appuser && \ + chown -R appuser:appuser /app +USER appuser + +# Expose Streamlit port +EXPOSE 8501 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8501/_stcore/health || exit 1 + +# Run Streamlit +CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] + diff --git a/sample_solutions/HybridSearch/ui/app.py b/sample_solutions/HybridSearch/ui/app.py new file mode 100644 index 00000000..c98d8842 --- /dev/null +++ b/sample_solutions/HybridSearch/ui/app.py @@ -0,0 +1,1610 @@ +""" +Streamlit UI for InsightMapper Lite - Hybrid Search RAG Application +Simplified Chat Interface with Document Upload +""" +import streamlit as st +import os +import httpx +import logging +import re +import time +from typing import Dict, Any, List, Optional +from datetime import datetime +import json +from config import settings +from streamlit_keycloak import login + +# Configure logging +logging.basicConfig( + level=getattr(logging, settings.log_level), + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Page configuration +st.set_page_config( + page_title="RAG Chatbot", + page_icon="💬", + layout="wide", + initial_sidebar_state="collapsed" +) + +# Custom CSS for simplified chat interface +st.markdown(""" + +""", unsafe_allow_html=True) + + +class UIService: + """ + Handle UI operations and API calls. + + Manages communication with the backend services (Gateway, Ingestion, Retrieval) + via HTTP requests. Handles authentication, file uploads, query submission, + and status polling. + """ + + def __init__(self): + """ + Initialize UI Service. + + Sets up API endpoints from environment variables and initializes the HTTP client. + """ + import os + gateway_host = os.getenv("GATEWAY_SERVICE_URL", settings.gateway_service_url) + ingestion_host = os.getenv("INGESTION_SERVICE_URL", "http://localhost:8004") + retrieval_host = os.getenv("RETRIEVAL_SERVICE_URL", "http://localhost:8002") + + self.gateway_url = gateway_host if gateway_host.startswith("http") else f"http://{gateway_host}" + self.ingestion_url = ingestion_host if ingestion_host.startswith("http") else f"http://{ingestion_host}" + self.retrieval_url = retrieval_host if retrieval_host.startswith("http") else f"http://{retrieval_host}" + self.llm_url = "http://localhost:8003" + self.client = httpx.Client(timeout=60.0) + self.token = None + + def set_token(self, token: str): + """ + Set authentication token for client headers. + + Args: + token (str): JWT access token. + """ + self.token = token + self.client.headers.update({"Authorization": f"Bearer {token}"}) + + def check_health(self) -> Dict[str, Any]: + """ + Check health of all services. + + Returns: + Dict[str, Any]: Health status of backend services. + """ + try: + response = self.client.get(f"{self.gateway_url}/api/v1/health/services") + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Health check failed: {e}") + return {"status": "error", "message": str(e)} + + def submit_query(self, query: str, include_debug: bool = False) -> Dict[str, Any]: + """ + Submit a query to the RAG system. + + Args: + query (str): The user's question. + include_debug (bool): Whether to request debug info in the response. + + Returns: + Dict[str, Any]: Normalized response with answer, citations, and metadata. + """ + try: + payload = { + "query": query, + "include_debug_info": include_debug + } + response = self.client.post( + f"{self.gateway_url}/api/v1/query", + json=payload + ) + response.raise_for_status() + data = response.json() + + # Normalize response fields to match UI expectations + normalized = { + "answer": data.get("answer", ""), + "citations": data.get("citations", []), + "query_type": data.get("query_complexity", data.get("query_type", "unknown")), + "model_used": data.get("llm_model", data.get("model_used", "unknown")), + "response_time_ms": data.get("processing_time_ms", data.get("response_time_ms", 0)), + "debug_info": data.get("debug_info"), + "retrieval_results_count": data.get("retrieval_results_count", 0) + } + + return normalized + except httpx.HTTPStatusError as e: + logger.error(f"Query failed with status {e.response.status_code}: {e}") + return { + "error": True, + "message": f"Server error: {e.response.status_code}", + "detail": e.response.text + } + except Exception as e: + logger.error(f"Query failed: {e}") + return {"error": True, "message": str(e)} + + def upload_document(self, file_data: bytes, filename: str) -> Dict[str, Any]: + """ + Upload a document for indexing. + + Args: + file_data (bytes): Raw file content. + filename (str): Name of the file. + + Returns: + Dict[str, Any]: Upload result containing document_id or error info. + """ + try: + # Verify ingestion service is accessible + logger.info(f"Uploading {filename} ({len(file_data)} bytes) to {self.ingestion_url}") + + files = {"file": (filename, file_data, "application/octet-stream")} + + # Use longer timeout for large files + timeout = 120.0 if len(file_data) > 10 * 1024 * 1024 else 60.0 + + with httpx.Client(timeout=timeout) as client: + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + response = client.post( + f"{self.ingestion_url}/api/v1/documents/upload", + files=files, + headers=headers + ) + logger.info(f"Upload response status: {response.status_code}") + + if response.status_code != 200 and response.status_code != 202: + error_text = response.text[:500] if response.text else "No error details" + logger.error(f"Upload failed: {response.status_code} - {error_text}") + return { + "error": True, + "message": f"Server error '{response.status_code} {response.reason_phrase}' for url '{self.ingestion_url}/api/v1/documents/upload'", + "detail": error_text + } + + response.raise_for_status() + return response.json() + + except httpx.ConnectError as e: + logger.error(f"Connection error to {self.ingestion_url}: {e}") + return { + "error": True, + "message": f"Cannot connect to ingestion service at {self.ingestion_url}. Is the service running?", + "detail": str(e) + } + except httpx.TimeoutException as e: + logger.error(f"Timeout uploading to {self.ingestion_url}: {e}") + return { + "error": True, + "message": f"Upload timeout. The file may be too large or the service is slow.", + "detail": str(e) + } + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error {e.response.status_code}: {e}") + error_text = e.response.text[:500] if e.response.text else str(e) + return { + "error": True, + "message": f"Server error '{e.response.status_code} {e.response.reason_phrase}' for url '{self.ingestion_url}/api/v1/documents/upload'", + "detail": error_text + } + except Exception as e: + logger.error(f"Document upload failed: {e}", exc_info=True) + return { + "error": True, + "message": f"Upload failed: {str(e)}", + "detail": str(e) + } + + def get_document_status(self, doc_id: str) -> Dict[str, Any]: + """ + Get status of an uploaded document. + + Args: + doc_id (str): Document ID. + + Returns: + Dict[str, Any]: Status info (processing_status, chunk_count, etc.). + """ + try: + response = self.client.get( + f"{self.ingestion_url}/api/v1/documents/{doc_id}/status" + ) + response.raise_for_status() + return response.json() + except httpx.ConnectError as e: + logger.error(f"Connection error to {self.ingestion_url}: {e}") + return {"error": True, "message": f"Cannot connect to ingestion service: {str(e)}"} + except Exception as e: + logger.error(f"Status check failed: {e}") + return {"error": True, "message": str(e)} + + def poll_document_status(self, doc_id: str, max_wait_seconds: int = 120) -> Dict[str, Any]: + """ + Poll document status until completion or timeout. + + Args: + doc_id (str): Document ID. + max_wait_seconds (int): Maximum seconds to wait. + + Returns: + Dict[str, Any]: Final status or timeout error. + """ + start_time = time.time() + while (time.time() - start_time) < max_wait_seconds: + status = self.get_document_status(doc_id) + if "error" in status: + return status + + processing_status = status.get("processing_status", "") + if processing_status in ["completed", "failed"]: + return status + + time.sleep(2) # Poll every 2 seconds + + return {"error": True, "message": "Timeout waiting for document processing"} + + def clear_all_indexes(self) -> Dict[str, Any]: + """ + Clear all vector indexes and metadata. + + Returns: + Dict[str, Any]: Operation result. + """ + try: + # Clear indexes in ingestion service + response = self.client.delete(f"{self.ingestion_url}/api/v1/documents/clear-all") + response.raise_for_status() + logger.info("Successfully cleared all indexes and metadata") + return response.json() + except Exception as e: + logger.error(f"Failed to clear indexes: {e}") + return {"error": True, "message": str(e)} + + def generate_document_summary(self, filename: str, preview_text: str = "") -> Dict[str, Any]: + """ + Generate a summary of the uploaded document. + + Constructs a query to ask the LLM for a summary of the provided text/document. + + Args: + filename (str): Name of the file. + preview_text (str): Optional text content to aid summarization. + + Returns: + Dict[str, Any]: Response containing the summary. + """ + try: + # Create a simple query to summarize the document + query = f"Please provide a brief summary of the document '{filename}'" + + # If we have preview text, use it as context + if preview_text: + response = self.submit_query( + f"Based on this document, provide a brief summary: {preview_text[:1000]}", + include_debug=False + ) + else: + # Otherwise just make a general query + response = self.submit_query( + "What is this document about? Provide a brief overview.", + include_debug=False + ) + + return response + except Exception as e: + logger.error(f"Failed to generate summary: {e}") + return { + "answer": f"Document '{filename}' has been successfully uploaded and indexed. You can now start asking questions about it!", + "citations": [] + } + + # Product Catalog Methods + def get_system_mode(self) -> str: + """ + Get current system mode. + + Returns: + str: 'document' or 'product'. + """ + try: + response = self.client.get(f"{self.ingestion_url}/api/v1/products/mode") + response.raise_for_status() + return response.json().get("mode", "document") + except Exception as e: + logger.error(f"Failed to get system mode: {e}") + return "document" + + def set_system_mode(self, mode: str) -> Dict[str, Any]: + """ + Switch system mode between document and product. + + Args: + mode (str): Target mode ('document' or 'product'). + + Returns: + Dict[str, Any]: Operation result. + """ + try: + response = self.client.post( + f"{self.ingestion_url}/api/v1/products/mode", + data={"mode": mode} + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to set system mode: {e}") + return {"error": True, "message": str(e)} + + def upload_product_catalog(self, file_data: bytes, filename: str) -> Dict[str, Any]: + """ + Upload product catalog CSV/JSON file. + + Args: + file_data (bytes): Raw file content. + filename (str): Name of the file. + + Returns: + Dict[str, Any]: Job info including job_id. + """ + try: + files = {"file": (filename, file_data)} + response = self.client.post( + f"{self.ingestion_url}/api/v1/products/upload", + files=files + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to upload product catalog: {e}") + return {"error": True, "message": str(e)} + + def confirm_product_mapping(self, job_id: str, catalog_name: str, field_mapping: Dict) -> Dict[str, Any]: + """ + Confirm product field mapping and start processing. + + Args: + job_id (str): Ingestion job ID. + catalog_name (str): Name for the catalog. + field_mapping (Dict): Mapping of file columns to standard product fields. + + Returns: + Dict[str, Any]: Confirmation result. + """ + try: + response = self.client.post( + f"{self.ingestion_url}/api/v1/products/confirm", + data={ + "job_id": job_id, + "catalog_name": catalog_name, + "field_mapping": json.dumps(field_mapping) + } + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to confirm product mapping: {e}") + return {"error": True, "message": str(e)} + + def get_product_ingestion_status(self, job_id: str) -> Dict[str, Any]: + """ + Get product ingestion job status. + + Args: + job_id (str): Job ID. + + Returns: + Dict[str, Any]: Job status. + """ + try: + response = self.client.get( + f"{self.ingestion_url}/api/v1/products/status/{job_id}" + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to get ingestion status: {e}") + return {"error": True, "message": str(e)} + + def get_catalog_info(self) -> Dict[str, Any]: + """ + Get current catalog information. + + Returns: + Dict[str, Any]: Catalog statistics (product count, categories). + """ + try: + response = self.client.get(f"{self.ingestion_url}/api/v1/products/catalog/info") + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to get catalog info: {e}") + return {"loaded": False, "message": str(e)} + + def get_all_products(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get all products from catalog. + + Args: + limit (int): Maximum number of products to return. + + Returns: + List[Dict[str, Any]]: List of product dictionaries. + """ + try: + # Use a generic query to get all products + response = self.client.post( + f"{self.gateway_url}/api/v1/search", + json={"query": "product", "limit": limit} + ) + response.raise_for_status() + result = response.json() + return result.get("results", []) + except Exception as e: + logger.error(f"Failed to get products: {e}") + return [] + + def search_products(self, query: str, filters: Optional[Dict] = None, limit: int = 100) -> Dict[str, Any]: + """ + Search products using natural language query. + + Args: + query (str): Search query. + filters (Optional[Dict]): Filters to apply. + limit (int): Max results. + + Returns: + Dict[str, Any]: Search results and interpreted filters. + """ + try: + payload = { + "query": query, + "limit": limit + } + if filters: + payload["filters"] = filters + + response = self.client.post( + f"{self.gateway_url}/api/v1/search", + json=payload + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Product search failed: {e}") + return {"error": True, "message": str(e), "results": []} + + def clear_product_catalog(self) -> Dict[str, Any]: + """ + Clear all products from catalog. + + Returns: + Dict[str, Any]: Operation result. + """ + try: + response = self.client.delete(f"{self.ingestion_url}/api/v1/products/clear") + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to clear product catalog: {e}") + return {"error": True, "message": str(e)} + + def reload_retrieval_indexes(self) -> Dict[str, Any]: + """ + Reload retrieval indexes. + + Forces the retrieval service to reload indexes from disk. + + Returns: + Dict[str, Any]: Operation result. + """ + try: + logger.info(f"Reloading indexes at {self.retrieval_url}") + response = self.client.post(f"{self.retrieval_url}/api/v1/reload") + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to reload indexes: {e}") + return {"error": True, "message": str(e)} + + +def initialize_session_state(): + """ + Initialize session state variables. + + Sets up default values for chat history, UI service, document status, + and product catalog state if they don't exist. + """ + if "chat_history" not in st.session_state: + st.session_state.chat_history = [] + if "ui_service" not in st.session_state: + st.session_state.ui_service = UIService() + if "current_document" not in st.session_state: + st.session_state.current_document = None + if "document_ready" not in st.session_state: + st.session_state.document_ready = False + if "active_citations" not in st.session_state: + st.session_state.active_citations = {} + if "upload_status" not in st.session_state: + st.session_state.upload_status = None + # Product catalog state + if "system_mode" not in st.session_state: + st.session_state.system_mode = "document" + if "catalog_loaded" not in st.session_state: + st.session_state.catalog_loaded = False + if "catalog_info" not in st.session_state: + st.session_state.catalog_info = None + if "product_upload_job" not in st.session_state: + st.session_state.product_upload_job = None + if "all_products" not in st.session_state: + st.session_state.all_products = [] + if "filtered_products" not in st.session_state: + st.session_state.filtered_products = [] + if "search_query" not in st.session_state: + st.session_state.search_query = "" + if "applied_filters" not in st.session_state: + st.session_state.applied_filters = {} + + +def process_citations_in_text(answer: str, citations: List[Dict]) -> str: + """ + Process citation markers in answer text. + + Replaces citation markers (e.g., [Page X]) with styled HTML badges. + + Args: + answer (str): Text response from LLM. + citations (List[Dict]): List of citation objects. + + Returns: + str: HTML-formatted answer with styled citations. + """ + # Find all citation patterns: [Page X], [Page X-Y], [Doc, Page X], or [X] + citation_pattern = r'\[(Page \d+(?:-\d+)?|[^\]]+, Page \d+|\d+)\]' + + def replace_citation(match): + citation_text = match.group(1) + return f'[{citation_text}]' + + # Replace all citations + processed_answer = re.sub(citation_pattern, replace_citation, answer) + return processed_answer + + +def render_header(): + """ + Render page header with mode switcher. + + Displays title and buttons to switch between 'RAG Chatbot' (Document) + and 'Product Catalog Search' modes. + """ + col1, col2 = st.columns([4, 1]) + with col1: + if st.session_state.system_mode == "product": + st.markdown('
🛍️ Product Catalog Search
', unsafe_allow_html=True) + st.markdown('
Search products with natural language queries
', unsafe_allow_html=True) + else: + st.markdown('
💬 RAG Chatbot
', unsafe_allow_html=True) + st.markdown('
Ask questions about your documents
', unsafe_allow_html=True) + with col2: + # Mode switcher + current_mode = st.session_state.system_mode + if current_mode == "document": + if st.button("🛍️ Switch to Products", use_container_width=True): + result = st.session_state.ui_service.set_system_mode("product") + if not result.get("error"): + st.session_state.system_mode = "product" + # Re-check catalog status instead of setting to False + st.session_state.catalog_info = None + st.session_state.all_products = [] + st.session_state.filtered_products = [] + st.rerun() + else: + if st.button("📄 Switch to Documents", use_container_width=True): + result = st.session_state.ui_service.set_system_mode("document") + if not result.get("error"): + st.session_state.system_mode = "document" + st.session_state.all_products = [] + st.session_state.filtered_products = [] + st.rerun() + + +def render_upload_panel(): + """ + Render left panel with document upload. + + Handles file upload widget, status display, and processing feedback loop. + """ + st.markdown('
', unsafe_allow_html=True) + + # Section header + st.markdown('
📄 Upload Document
', unsafe_allow_html=True) + st.markdown('
Upload a PDF to start asking questions
', unsafe_allow_html=True) + + # Show upload status + if not st.session_state.current_document: + st.markdown( + '
⚠️ No document uploaded
', + unsafe_allow_html=True + ) + else: + doc_name = st.session_state.current_document.get("filename", "Unknown") + chunk_count = st.session_state.current_document.get("chunk_count", 0) + st.markdown( + f'
{doc_name}
' + f'{chunk_count} chunks indexed
', + unsafe_allow_html=True + ) + + # Upload interface + st.markdown(''' +
+
📤
+
Drop your PDF here
+
or
+
+ ''', unsafe_allow_html=True) + + uploaded_file = st.file_uploader( + "Choose a file", + type=["pdf", "docx", "txt"], + help="Supported formats: PDF, DOCX, TXT (max 100MB per file)", + label_visibility="collapsed", + key="file_uploader" + ) + + if uploaded_file: + st.info(f"📄 {uploaded_file.name} ({uploaded_file.size / 1024:.1f} KB)") + + upload_button = st.button( + "🚀 Upload", + type="primary", + use_container_width=True, + disabled=(uploaded_file is None) + ) + + if upload_button and uploaded_file is not None: + # Create placeholder for status updates + status_placeholder = st.empty() + progress_bar = st.progress(0) + + # Clear existing indexes silently (single document mode) + # This ensures we always start fresh with just one document + status_placeholder.info("🗑️ Clearing previous data...") + progress_bar.progress(10) + st.session_state.ui_service.clear_all_indexes() + + # Upload document + status_placeholder.info("⬆️ Uploading document...") + progress_bar.progress(20) + + result = st.session_state.ui_service.upload_document( + uploaded_file.read(), + uploaded_file.name + ) + + if "error" in result: + status_placeholder.error(f"❌ Upload failed: {result['message']}") + progress_bar.empty() + else: + doc_id = result.get("document_id", result.get("doc_id")) + + # Poll for processing status + status_placeholder.info("🔄 Processing document...") + progress_bar.progress(40) + + max_attempts = 60 + attempt = 0 + while attempt < max_attempts: + status_info = st.session_state.ui_service.get_document_status(doc_id) + + if "error" in status_info: + status_placeholder.error(f"❌ Status check failed: {status_info['message']}") + break + + processing_status = status_info.get("processing_status", "unknown") + chunk_count = status_info.get("chunk_count", 0) + + if processing_status == "completed": + progress_bar.progress(100) + status_placeholder.success(f"✅ Document processed! ({chunk_count} chunks indexed)") + + # Reload retrieval service to pick up new indexes + try: + logger.info("Reloading retrieval service with new document data") + retrieval_url = "http://localhost:8002" + reload_response = st.session_state.ui_service.client.post(f"{retrieval_url}/api/v1/reload") + reload_response.raise_for_status() + logger.info("Successfully reloaded retrieval service after document upload") + except Exception as reload_error: + logger.warning(f"Failed to reload retrieval service: {reload_error}") + + # Store document info + st.session_state.current_document = { + "doc_id": doc_id, + "filename": uploaded_file.name, + "timestamp": datetime.now().isoformat(), + "chunk_count": chunk_count + } + st.session_state.document_ready = True + + # Reload retrieval indexes so the new document can be found + with st.spinner("Reloading search indexes..."): + st.session_state.ui_service.reload_retrieval_indexes() + + # Generate summary and add as first message + with st.spinner("Generating document summary..."): + summary_response = st.session_state.ui_service.generate_document_summary( + uploaded_file.name + ) + + # Clear chat and add welcome message + st.session_state.chat_history = [] + + welcome_message = { + "type": "assistant", + "response": summary_response, + "timestamp": datetime.now().strftime("%I:%M %p"), + "id": "welcome_message", + "is_welcome": True + } + st.session_state.chat_history.append(welcome_message) + + time.sleep(1) + status_placeholder.empty() + progress_bar.empty() + st.rerun() + break + + elif processing_status == "failed": + error_msg = status_info.get("error_message", "Unknown error") + status_placeholder.error(f"❌ Processing failed: {error_msg}") + progress_bar.empty() + break + + elif processing_status == "processing": + progress = min(40 + (attempt * 50 // max_attempts), 90) + progress_bar.progress(progress) + status_placeholder.info(f"🔄 Processing document... ({attempt * 2}s)") + + time.sleep(2) + attempt += 1 + + if attempt >= max_attempts: + status_placeholder.warning("⏱️ Processing is taking longer than expected.") + progress_bar.empty() + + # Instructions + st.markdown(''' +
+
Instructions:
+
    +
  • Upload a PDF document (max 100MB)
  • +
  • Wait for processing to complete
  • +
  • Start asking questions in the chat
  • +
  • Get intelligent answers based on your document
  • +
+
+ ''', unsafe_allow_html=True) + + st.markdown('
', unsafe_allow_html=True) # Close upload-panel + + +def render_chat_panel(): + """ + Render right panel with chat interface. + + Displays chat history, empty state (if no doc), and chat input. + Handles message submission and response rendering. + """ + st.markdown('
', unsafe_allow_html=True) + + # Section header + st.markdown('
💬 Chat Assistant
', unsafe_allow_html=True) + st.markdown('
Upload a document to start chatting
', unsafe_allow_html=True) + + if not st.session_state.document_ready: + # Empty state + st.markdown(''' +
+
🤖
+
No Document Loaded
+
Upload a PDF document on the left to start asking questions and get intelligent answers powered by AI
+
+ ''', unsafe_allow_html=True) + else: + + # Chat messages container + chat_container = st.container() + + with chat_container: + if st.session_state.chat_history: + for message in st.session_state.chat_history: + render_chat_message(message) + + # Chat input at bottom + st.markdown("---") + + col1, col2 = st.columns([5, 1]) + + with col1: + query = st.text_input( + "Type your question...", + placeholder="Upload a document first...", + key="chat_input", + label_visibility="collapsed" + ) + + with col2: + submit_button = st.button("📤 Send", type="primary", use_container_width=True) + + # Help text + st.caption("Press Enter to send • The AI will answer based on your uploaded document") + + # Process query + if submit_button and query.strip(): + # Add user message + user_message = { + "type": "user", + "content": query, + "timestamp": datetime.now().strftime("%I:%M %p"), + "id": f"user_{len(st.session_state.chat_history)}" + } + st.session_state.chat_history.append(user_message) + + # Get response + with st.spinner("🤔 Thinking..."): + response = st.session_state.ui_service.submit_query(query, include_debug=False) + + # Add assistant message + assistant_message = { + "type": "assistant", + "response": response, + "timestamp": datetime.now().strftime("%I:%M %p"), + "id": f"assistant_{len(st.session_state.chat_history)}" + } + st.session_state.chat_history.append(assistant_message) + + st.rerun() + + elif submit_button: + st.warning("⚠️ Please enter a question") + + st.markdown('
', unsafe_allow_html=True) # Close chat-panel + + +def render_chat_message(message: Dict[str, Any]): + """ + Render a single chat message (user or assistant). + + Args: + message (Dict[str, Any]): Message object containing type, content/response, etc. + """ + message_type = message.get("type", "assistant") + + if message_type == "user": + # User message bubble + st.markdown( + f'
{message["content"]}
', + unsafe_allow_html=True + ) + else: + # Assistant message + response = message.get("response", {}) + + if "error" in response: + st.error(f"❌ {response.get('message', 'An error occurred')}") + return + + # Get answer and process citations + answer = response.get("answer", "No answer generated") + citations = response.get("citations", []) + + # Check if this is the welcome message + is_welcome = message.get("is_welcome", False) + + if is_welcome: + # Format welcome message differently + welcome_text = f""" +
+
📄 Document Summary
+
{answer}
+
+ 💡 Let me know how I can help you with this document! +
+
+ """ + st.markdown(welcome_text, unsafe_allow_html=True) + else: + # Regular message with citations + if citations: + processed_answer = process_citations_in_text(answer, citations) + else: + processed_answer = answer + + st.markdown(f'
{processed_answer}
', unsafe_allow_html=True) + + # Show sources if available + if citations: + with st.expander(f"📚 View {len(citations)} Source(s)", expanded=False): + for i, citation in enumerate(citations[:5], 1): + doc_id = citation.get('document_id', 'N/A') + page_num = citation.get('page_number', 'N/A') + snippet = citation.get('relevant_text_snippet', '') + + st.markdown(f"**[{i}]** Page {page_num}") + if snippet: + st.text(snippet[:200] + "..." if len(snippet) > 200 else snippet) + st.markdown("---") + + + + +def render_catalog_sidebar(): + """ + Render compact sidebar for catalog management. + + Displays catalog status, clear button, and file uploader for new catalogs. + """ + with st.sidebar: + st.markdown("### 📦 Catalog Management") + + # Check catalog status + if st.session_state.catalog_info is None: + catalog_info = st.session_state.ui_service.get_catalog_info() + st.session_state.catalog_info = catalog_info + st.session_state.catalog_loaded = catalog_info.get("loaded", False) + + if st.session_state.catalog_loaded: + info = st.session_state.catalog_info + st.success(f"✅ {info.get('product_count', 0)} Products Loaded") + st.caption(f"Categories: {len(info.get('categories', []))}") + + if st.button("🗑️ Clear Catalog", use_container_width=True): + result = st.session_state.ui_service.clear_product_catalog() + if not result.get("error"): + st.session_state.catalog_loaded = False + st.session_state.catalog_info = None + st.session_state.all_products = [] + st.session_state.filtered_products = [] + st.rerun() + else: + st.warning("⚠️ No catalog loaded") + + # File upload + st.markdown("---") + st.markdown("**Upload New Catalog**") + uploaded_file = st.file_uploader( + "Select File", + type=["csv", "json", "xlsx"], + help="Upload CSV/JSON with products" + ) + + if uploaded_file is not None: + if st.button("📤 Upload", type="primary", use_container_width=True): + with st.spinner("Processing..."): + file_data = uploaded_file.read() + result = st.session_state.ui_service.upload_product_catalog( + file_data, uploaded_file.name + ) + + if result.get("error"): + st.error(f"Upload failed") + else: + job_id = result.get("job_id") + + # Auto-confirm mapping + if result.get("requires_confirmation"): + suggested = result.get("suggested_mapping", {}) + st.session_state.ui_service.confirm_product_mapping( + job_id, "Products", suggested + ) + + # Monitor processing + progress_bar = st.progress(0) + max_attempts = 30 + + for attempt in range(max_attempts): + status = st.session_state.ui_service.get_product_ingestion_status(job_id) + if status.get("status") == "complete": + progress_bar.progress(100) + st.session_state.catalog_loaded = True + st.session_state.catalog_info = None + st.session_state.all_products = [] + + # Reload retrieval indexes + with st.spinner("Reloading search indexes..."): + st.session_state.ui_service.reload_retrieval_indexes() + + time.sleep(1) + st.rerun() + break + elif status.get("status") == "failed": + st.error("Processing failed") + break + else: + progress = min(10 + (attempt * 80 // max_attempts), 90) + progress_bar.progress(progress) + time.sleep(2) + + +def render_ecommerce_store(): + """ + Render main e-commerce product display. + + Shows product grid, search bar, and empty states. + """ + + if not st.session_state.catalog_loaded: + st.markdown(''' +
+
🛍️
+
No Products Available
+
Upload a product catalog from the sidebar to get started
+
+ ''', unsafe_allow_html=True) + return + + # Load all products if not already loaded + if not st.session_state.all_products: + with st.spinner("Loading products..."): + products = st.session_state.ui_service.get_all_products(limit=100) + st.session_state.all_products = products + st.session_state.filtered_products = products + + # Search and filter bar + render_search_bar() + + # Display products + products_to_show = st.session_state.filtered_products if st.session_state.filtered_products else st.session_state.all_products + + if not products_to_show: + st.info("🔍 No products found matching your search. Try different keywords or filters.") + return + + # Results header + st.markdown(f'

{len(products_to_show)} Products

', unsafe_allow_html=True) + + # Product grid + render_product_grid(products_to_show) + + +def render_search_bar(): + """ + Render search bar with filters. + + Handles text input, search button, clear button, and active filter display. + """ + st.markdown('
', unsafe_allow_html=True) + + # Search input + col1, col2, col3 = st.columns([6, 1, 1]) + + with col1: + query = st.text_input( + "Search", + placeholder="Search for products...", + key="search_input", + label_visibility="collapsed" + ) + + with col2: + search_clicked = st.button("🔍 Search", type="primary", use_container_width=True) + + with col3: + clear_clicked = st.button("Clear", use_container_width=True) + + # Process search + if search_clicked and query.strip(): + st.session_state.search_query = query + with st.spinner("Searching..."): + results = st.session_state.ui_service.search_products(query, limit=100) + if not results.get("error"): + st.session_state.filtered_products = results.get("results", []) + st.session_state.applied_filters = results.get("query_interpretation", {}).get("extracted_filters", {}) + st.rerun() + + # Clear search + if clear_clicked: + st.session_state.search_query = "" + st.session_state.filtered_products = st.session_state.all_products + st.session_state.applied_filters = {} + st.rerun() + + # Show active filters + if st.session_state.search_query or st.session_state.applied_filters: + st.markdown('
', unsafe_allow_html=True) + + if st.session_state.search_query: + st.markdown(f'🔍 "{st.session_state.search_query}"', unsafe_allow_html=True) + + filters = st.session_state.applied_filters + if filters.get("price_max"): + st.markdown(f'💰 Under ${filters["price_max"]}', unsafe_allow_html=True) + if filters.get("price_min"): + st.markdown(f'💰 Over ${filters["price_min"]}', unsafe_allow_html=True) + if filters.get("rating_min"): + st.markdown(f'⭐ {filters["rating_min"]}+ stars', unsafe_allow_html=True) + if filters.get("categories"): + for cat in filters["categories"]: + st.markdown(f'📂 {cat}', unsafe_allow_html=True) + + st.markdown('
', unsafe_allow_html=True) + + st.markdown('
', unsafe_allow_html=True) + + +def render_product_grid(products: List[Dict[str, Any]]): + """ + Render products in a responsive grid. + + Args: + products (List[Dict[str, Any]]): List of product dictionaries to display. + """ + # Display 4 products per row + cols_per_row = 4 + + for i in range(0, len(products), cols_per_row): + cols = st.columns(cols_per_row) + for j, col in enumerate(cols): + if i + j < len(products): + with col: + render_product_card(products[i + j]) + + +def render_product_card(product: Dict[str, Any]): + """ + Render a single product card with image. + + Starts HTML block for product card styling. + + Args: + product (Dict[str, Any]): Product data. + """ + # Extract product data + name = product.get("name", "Unknown Product") + price = product.get("price") + rating = product.get("rating") + review_count = product.get("review_count", 0) + category = product.get("category", "") + image_url = product.get("image_url") or product.get("metadata", {}).get("image_url", "") + brand = product.get("brand") or product.get("metadata", {}).get("brand", "") + + # Fallback image if none provided + if not image_url: + image_url = "https://via.placeholder.com/400x400/e5e7eb/6b7280?text=No+Image" + + # Rating stars + stars_filled = int(rating) if rating else 0 + stars_empty = 5 - stars_filled + stars_html = "★" * stars_filled + "☆" * stars_empty + + # Product card HTML + card_html = f""" +
+ {name} +
+
{name}
+
+ {stars_html} + ({review_count:,}) +
+
${price:.2f}
+
{category}
+
+
+ """ + + st.markdown(card_html, unsafe_allow_html=True) + + +def main(): + """ + Main application entry point. + + Initializes session state, routing, and proper page layout based on system mode. + """ + # Keycloak Login + keycloak_config = { + "url": os.getenv("KEYCLOAK_URL", os.getenv("BASE_URL", "http://localhost:8080")), + "realm": os.getenv("KEYCLOAK_REALM", "master"), + "client_id": os.getenv("KEYCLOAK_CLIENT_ID", "api") + } + + # keycloak = login( + # url=keycloak_config["url"], + # realm=keycloak_config["realm"], + # client_id=keycloak_config["client_id"], + # init_options={'checkLoginIframe': False} + # ) + + # if not keycloak.authenticated: + # st.warning("Please login to access the system.") + # st.stop() + + initialize_session_state() + + # Note: Keycloak authentication is handled at the service level (embedding, llm, etc.) + # The UI communicates with services through the gateway without needing to pass tokens + + # Get current system mode + if "system_mode_initialized" not in st.session_state: + current_mode = st.session_state.ui_service.get_system_mode() + st.session_state.system_mode = current_mode + st.session_state.system_mode_initialized = True + + render_header() + + st.markdown('
', unsafe_allow_html=True) + + if st.session_state.system_mode == "product": + # E-commerce mode: sidebar + full-width store + render_catalog_sidebar() + render_ecommerce_store() + else: + # Document mode: two-column layout + col1, col2 = st.columns([1, 2], gap="medium") + with col1: + render_upload_panel() + with col2: + render_chat_panel() + + +if __name__ == "__main__": + main() + diff --git a/sample_solutions/HybridSearch/ui/config.py b/sample_solutions/HybridSearch/ui/config.py new file mode 100644 index 00000000..7670dedf --- /dev/null +++ b/sample_solutions/HybridSearch/ui/config.py @@ -0,0 +1,37 @@ +""" +Configuration for UI Service +""" +from pydantic_settings import BaseSettings +from pathlib import Path + + +class Settings(BaseSettings): + # Service URLs (use localhost for local dev, gateway for Docker) + gateway_service_url: str = "http://localhost:8000" + + # UI Configuration + ui_title: str = "InsightMapper Lite" + ui_page_icon: str = "📚" + ui_layout: str = "wide" + + # Feature flags + enable_debug_mode: bool = True + enable_document_upload: bool = True + enable_query_history: bool = True + + # Display settings + max_results_display: int = 5 + show_confidence_scores: bool = True + show_source_preview: bool = True + + # Logging + log_level: str = "INFO" + + class Config: + env_file = Path(__file__).parent.parent / ".env" + case_sensitive = False + extra = "ignore" + + +settings = Settings() + diff --git a/sample_solutions/HybridSearch/ui/public/citations.png b/sample_solutions/HybridSearch/ui/public/citations.png new file mode 100644 index 00000000..9f20587e Binary files /dev/null and b/sample_solutions/HybridSearch/ui/public/citations.png differ diff --git a/sample_solutions/HybridSearch/ui/public/product_catalog.png b/sample_solutions/HybridSearch/ui/public/product_catalog.png new file mode 100644 index 00000000..74274f52 Binary files /dev/null and b/sample_solutions/HybridSearch/ui/public/product_catalog.png differ diff --git a/sample_solutions/HybridSearch/ui/public/rag_chatbot.png b/sample_solutions/HybridSearch/ui/public/rag_chatbot.png new file mode 100644 index 00000000..1f926b02 Binary files /dev/null and b/sample_solutions/HybridSearch/ui/public/rag_chatbot.png differ diff --git a/sample_solutions/HybridSearch/ui/requirements.txt b/sample_solutions/HybridSearch/ui/requirements.txt new file mode 100644 index 00000000..a17ef62a --- /dev/null +++ b/sample_solutions/HybridSearch/ui/requirements.txt @@ -0,0 +1,9 @@ +streamlit==1.29.0 +requests==2.31.0 +httpx==0.28.1 +pandas==2.1.4 +plotly==5.18.0 +pydantic==2.5.0 +pydantic-settings==2.1.0 +python-dotenv==1.0.0 +streamlit-keycloak==1.1.1