diff --git a/.github/workflows/stress-test-mcp-server.yml b/.github/workflows/stress-test-mcp-server.yml index b063ace5..33387e0a 100644 --- a/.github/workflows/stress-test-mcp-server.yml +++ b/.github/workflows/stress-test-mcp-server.yml @@ -66,7 +66,7 @@ jobs: KAI_DB_DSN: "postgresql+asyncpg://kai_user:kai_password@localhost:5432/kai_test_db" KAI_LLM_PARAMS: '{"model": "fake", "responses": ["Test response"]}' MCP_SERVER_URL: "http://localhost:8000" - NUM_CONCURRENT_CLIENTS: ${{ github.event.inputs.num_clients || '100' }} + NUM_CONCURRENT_CLIENTS: ${{ github.event.inputs.num_clients || '200' }} run: | echo "Starting MCP server connected to PostgreSQL..." uv run python -m kai_mcp_solution_server --transport streamable-http --host 0.0.0.0 --port 8000 & diff --git a/kai_mcp_solution_server/Makefile b/kai_mcp_solution_server/Makefile index 5c5e8a92..6ddfa0a3 100644 --- a/kai_mcp_solution_server/Makefile +++ b/kai_mcp_solution_server/Makefile @@ -103,6 +103,7 @@ run-podman: build podman-postgres: build @echo "Starting MCP solution server with PostgreSQL using podman-compose..." 
async def kill_idle_connections(
    engine: AsyncEngine, application_name: str = "kai-solution-server"
) -> None:
    """Terminate idle PostgreSQL backends opened under ``application_name``.

    The statement selects every row of ``pg_stat_activity`` that belongs to the
    current database, carries the given application name, is in state ``idle``
    and is not the backend executing the statement, and calls
    ``pg_terminate_backend`` on it.

    Args:
        engine: Engine used to run the statement. Nothing is executed unless
            its dialect name is ``"postgresql"``.
        application_name: Application name to match. The default is the value
            this module sets on every new PostgreSQL connection.

    Note:
        Connections are matched by name only, so idle pooled connections of
        this process (and of any other process using the same application name
        on the same database) are terminated as well.
    """
    # pg_stat_activity / pg_terminate_backend are PostgreSQL-only; silently
    # skip every other backend.
    if engine.dialect.name != "postgresql":
        return

    # The name is sent as a bind parameter rather than spliced into the SQL,
    # and the match is limited to the current database so that same-named
    # sessions on other databases of the cluster are left alone.
    terminate_idle = text(
        """
        SELECT pg_terminate_backend(pid)
        FROM pg_stat_activity
        WHERE application_name = :application_name
        AND datname = current_database()
        AND state = 'idle'
        AND pid != pg_backend_pid()
        """
    )

    async with engine.begin() as conn:
        await conn.execute(terminate_idle, {"application_name": application_name})
def with_db_recovery(
    func: Callable[..., Coroutine[Any, Any, T]]
) -> Callable[..., Coroutine[Any, Any, T]]:
    """Decorate a database coroutine with bounded concurrency and retries.

    The wrapped coroutine must take a ``KaiSolutionServerContext`` as its first
    positional argument. Each call:

    * acquires ``_SharedResources.db_semaphore`` for its whole duration
      (including backoff sleeps and recovery), which caps the number of
      decorated operations in flight;
    * runs ``func`` up to three times, sleeping 0.1s, then 0.2s between
      attempts;
    * after every failed attempt that will be retried, runs
      ``_recover_from_db_error()`` and then ``kai_ctx.create()`` so the context
      picks up the shared engine / session maker again, because recovery may
      have replaced them.

    Args:
        func: Coroutine function to wrap.

    Returns:
        A coroutine function with the same call signature and result.

    Raises:
        RuntimeError: If the shared semaphore has not been initialized yet.
        IntegrityError: Re-raised immediately, never retried.
        SQLAlchemyTimeoutError: If the pool still times out on the last attempt.
        DBAPIError: If the database error persists on the last attempt.
    """
    max_retries = 3
    base_delay = 0.1  # seconds; doubled after every failed attempt

    async def _backoff_and_recover(
        kai_ctx: KaiSolutionServerContext, delay: float
    ) -> None:
        # Shared tail of both retry paths: wait, try to free connections, then
        # re-bind the context to whatever shared resources exist afterwards.
        await asyncio.sleep(delay)
        await _recover_from_db_error()
        await kai_ctx.create()

    @functools.wraps(func)
    async def wrapper(
        kai_ctx: KaiSolutionServerContext, *args: Any, **kwargs: Any
    ) -> T:
        semaphore = _SharedResources.db_semaphore
        if semaphore is None:
            raise RuntimeError("Database semaphore not initialized")

        # Semaphore ensures we don't overwhelm the connection pool
        async with semaphore:
            for attempt in range(max_retries):
                is_last_attempt = attempt >= max_retries - 1
                try:
                    return await func(kai_ctx, *args, **kwargs)
                except IntegrityError:
                    # Must be listed before DBAPIError (its base class):
                    # a constraint violation will not go away on retry.
                    raise
                except SQLAlchemyTimeoutError:
                    if is_last_attempt:
                        log(f"Connection pool exhausted after {max_retries} attempts")
                        raise
                    delay = base_delay * (2**attempt)  # Exponential backoff
                    log(
                        f"Connection pool timeout (attempt {attempt + 1}), retrying in {delay}s..."
                    )
                    await _backoff_and_recover(kai_ctx, delay)
                except (DBAPIError, OperationalError) as e:
                    if is_last_attempt:
                        log(f"Database error after {max_retries} attempts, giving up")
                        raise
                    delay = base_delay * (2**attempt)  # Exponential backoff
                    log(
                        f"Database error (attempt {attempt + 1}): {e}, retrying in {delay}s..."
                    )
                    await _backoff_and_recover(kai_ctx, delay)

            # This should never be reached due to the loop structure,
            # but mypy needs an explicit unreachable marker
            raise RuntimeError("Unexpected: retry loop completed without returning")

    return wrapper
async def _recover_from_db_error() -> None:
    """Try to bring the shared database engine back into a usable state.

    First asks the database to terminate this application's idle connections.
    If that attempt itself fails, the shared engine is disposed and the shared
    resources are initialized again.

    Does nothing when no shared engine exists yet. Errors raised while
    re-initializing the shared resources propagate to the caller.
    """
    # Capture the engine once: several tasks can run recovery concurrently and
    # the shared attribute may be replaced while this coroutine is suspended.
    stale_engine = _SharedResources.engine
    if stale_engine is None:
        return

    log("Recovering from database error - killing idle connections...")
    try:
        await kill_idle_connections(stale_engine)
    except Exception as e:
        log(f"Failed to kill idle connections: {e}")
    else:
        log("Successfully killed idle connections")
        return

    if _SharedResources.engine is not stale_engine:
        # Another task already rebuilt the engine while we were waiting;
        # disposing it again would only throw away healthy connections.
        return

    log("Disposing and recreating engine...")
    try:
        await stale_engine.dispose()
    except Exception as e:
        # A broken engine may fail to shut down cleanly; still replace it.
        log(f"Failed to dispose engine: {e}")

    # Only force a rebuild if nobody replaced the engine during dispose().
    # No await separates this check from the call below, so it cannot race.
    if _SharedResources.engine is stale_engine:
        _SharedResources.initialized = False
    await _initialize_shared_resources()
ExtendedIncident, ) -> int: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: violation_stmt = select(DBViolation).where( DBViolation.ruleset_name == extended_incident.ruleset_name, @@ -265,6 +397,7 @@ async def tool_create_multiple_incidents( return results +@with_db_recovery async def create_solution( kai_ctx: KaiSolutionServerContext, client_id: str, @@ -280,6 +413,8 @@ async def create_solution( if used_hint_ids is None: used_hint_ids = [] + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: incident_ids_cond = [ and_(DBIncident.id == incident_id, DBIncident.client_id == client_id) @@ -394,10 +529,13 @@ async def tool_create_solution( ) +@with_db_recovery async def generate_hint_v1( kai_ctx: KaiSolutionServerContext, client_id: str, ) -> None: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: solutions_stmt = select(DBSolution).where( DBSolution.client_id == client_id, @@ -440,6 +578,8 @@ async def generate_hint_v1( log(f"Generating hint for client {client_id} with prompt:\n{prompt}") + if kai_ctx.model is None: + raise RuntimeError("Model not initialized") response = await kai_ctx.model.ainvoke(prompt) log(f"Generated hint: {response.content}") @@ -458,11 +598,14 @@ async def generate_hint_v1( await session.flush() +@with_db_recovery async def generate_hint_v2( kai_ctx: KaiSolutionServerContext, client_id: str, ) -> None: # print(f"Generating hint for client {client_id}", file=sys.stderr) + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: solutions_stmt = select(DBSolution).where( DBSolution.client_id == client_id, @@ -518,6 +661,8 @@ async def generate_hint_v2( # print(f"Generating hint for client 
{client_id} with prompt:\n{prompt}", file=sys.stderr) + if kai_ctx.model is None: + raise RuntimeError("Model not initialized") response = await kai_ctx.model.ainvoke(prompt) # print(f"Generated hint: {response.content}", file=sys.stderr) @@ -536,6 +681,7 @@ async def generate_hint_v2( await session.flush() +@with_db_recovery async def generate_hint_v3( kai_ctx: KaiSolutionServerContext, client_id: str, @@ -543,6 +689,8 @@ async def generate_hint_v3( """ Generate hints for accepted solutions using improved prompt format with better structure. """ + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: solutions_stmt = select(DBSolution).where( DBSolution.client_id == client_id, @@ -610,6 +758,8 @@ async def generate_hint_v3( ast_diff_str = "\n\n".join(str(a) for a in ast_diffs if a is not None) prompt += f"AST Diff:\n{ast_diff_str}\n\n" + if kai_ctx.model is None: + raise RuntimeError("Model not initialized") response = await kai_ctx.model.ainvoke(prompt) hint = DBHint( @@ -626,11 +776,14 @@ async def generate_hint_v3( await session.flush() +@with_db_recovery async def delete_solution( kai_ctx: KaiSolutionServerContext, client_id: str, solution_id: int, ) -> bool: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: sln = await session.get(DBSolution, solution_id) if sln is None: @@ -665,11 +818,14 @@ class GetBestHintResult(BaseModel): hint_id: int +@with_db_recovery async def get_best_hint( kai_ctx: KaiSolutionServerContext, ruleset_name: str, violation_name: str, ) -> GetBestHintResult | None: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: violation_name_stmt = select(DBViolation).where( DBViolation.ruleset_name == ruleset_name, @@ -727,6 +883,7 @@ class SuccessRateMetric(BaseModel): 
unknown_solutions: int = 0 +@with_db_recovery async def get_success_rate( kai_ctx: KaiSolutionServerContext, violation_ids: list[ViolationID], @@ -736,6 +893,8 @@ async def get_success_rate( if len(violation_ids) == 0: return result + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: violations_where = or_( # type: ignore[arg-type] and_( @@ -809,11 +968,14 @@ async def tool_get_success_rate( ) +@with_db_recovery async def accept_file( kai_ctx: KaiSolutionServerContext, client_id: str, solution_file: SolutionFile, ) -> None: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: solutions_stmt = select(DBSolution).where(DBSolution.client_id == client_id) solutions = (await session.execute(solutions_stmt)).scalars().all() @@ -897,11 +1059,14 @@ async def tool_accept_file( ) +@with_db_recovery async def reject_file( kai_ctx: KaiSolutionServerContext, client_id: str, file_uri: str, ) -> None: + if kai_ctx.session_maker is None: + raise RuntimeError("Session maker not initialized") async with kai_ctx.session_maker.begin() as session: solutions_stmt = select(DBSolution).where(DBSolution.client_id == client_id) solutions = (await session.execute(solutions_stmt)).scalars().all() diff --git a/kai_mcp_solution_server/tests/test_multiple_integration.py b/kai_mcp_solution_server/tests/test_multiple_integration.py index 4c557fcf..7cd515ca 100644 --- a/kai_mcp_solution_server/tests/test_multiple_integration.py +++ b/kai_mcp_solution_server/tests/test_multiple_integration.py @@ -1,6 +1,4 @@ import asyncio -import concurrent -import concurrent.futures import datetime import json import os @@ -398,20 +396,9 @@ async def test_multiple_users(self) -> None: ) # Don't set KAI_LLM_PARAMS for external server - it should already be configured - def run_async_in_thread(fn, *args, **kwargs): - try: - loop = 
asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - result = loop.run_until_complete(fn(*args, **kwargs)) - return result - finally: - loop.close() - - async def client_task(client_id: str) -> None: + async def client_task( + client_id: str, ready_event: asyncio.Event, release_event: asyncio.Event + ) -> None: print(f"[Client {client_id}] starting") apply_ssl_bypass() @@ -439,8 +426,8 @@ async def client_task(client_id: str) -> None: "extended_incident": { "uri": f"file://test/file_{client_id}.java", "message": f"Test issue for client {client_id}", - "ruleset_name": f"test-ruleset-{client_id % 10}", # Share some rulesets - "violation_name": f"test-violation-{client_id % 5}", # Share some violations + "ruleset_name": f"test-ruleset-{client_id}", # Unique per client for predictable tests + "violation_name": f"test-violation-{client_id}", # Unique per client for predictable tests "violation_category": "potential", "code_snip": "// test code", "line_number": 42, @@ -512,8 +499,8 @@ async def client_task(client_id: str) -> None: { "violation_ids": [ { - "ruleset_name": f"test-ruleset-{client_id % 10}", - "violation_name": f"test-violation-{client_id % 5}", + "ruleset_name": f"test-ruleset-{client_id}", + "violation_name": f"test-violation-{client_id}", } ] }, @@ -568,8 +555,8 @@ async def client_task(client_id: str) -> None: { "violation_ids": [ { - "ruleset_name": f"test-ruleset-{client_id % 10}", - "violation_name": f"test-violation-{client_id % 5}", + "ruleset_name": f"test-ruleset-{client_id}", + "violation_name": f"test-violation-{client_id}", } ] }, @@ -617,8 +604,18 @@ async def client_task(client_id: str) -> None: print( f"[Client {client_id}] āœ“ All operations completed successfully" ) + + # Signal that this client is ready and wait for release + ready_event.set() + print( + f"[Client {client_id}] waiting for all clients to complete..." 
+ ) + await release_event.wait() + print(f"[Client {client_id}] released, closing connection") + except Exception as e: print(f"[Client {client_id}] ERROR: {e}") + ready_event.set() # Still signal ready even on error raise # Re-raise to fail the test # External server should already be running @@ -626,34 +623,53 @@ async def client_task(client_id: str) -> None: NUM_TASKS = int(os.environ.get("NUM_CONCURRENT_CLIENTS", "30")) print(f"Testing with {NUM_TASKS} concurrent clients") - with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_TASKS) as executor: - # Submit each task to the thread pool and store the Future objects. - # The executor will call run_async_in_thread for each task ID. - futures = { - executor.submit(run_async_in_thread, client_task, i): i - for i in range(1, NUM_TASKS + 1) - } + # Create events for synchronization + ready_events = [asyncio.Event() for _ in range(NUM_TASKS)] + release_event = asyncio.Event() - # Use as_completed() to process results as they become available. - # Fail fast - if any task fails, immediately fail the whole test - for future in concurrent.futures.as_completed(futures): - task_id = futures[future] - try: - result = future.result() - print( - f"[Main] received result for Task {task_id}: {result}", - flush=True, - ) - except Exception as exc: - # Fail immediately with detailed error information - self.fail(f"Task {task_id} failed: {exc}") + # Launch all client tasks concurrently + tasks = [ + asyncio.create_task(client_task(i, ready_events[i - 1], release_event)) + for i in range(1, NUM_TASKS + 1) + ] - await asyncio.sleep(10) # wait a moment for all output to be printed + print(f"Waiting for all {NUM_TASKS} clients to complete their operations...") + + # Wait for all clients to signal they're ready (operations complete, connections still open) + await asyncio.gather(*[event.wait() for event in ready_events]) + + print( + f"All {NUM_TASKS} clients have completed operations with connections still open!" 
+ ) + print( + "Holding all connections open for 5 seconds to stress test the connection pool..." + ) + await asyncio.sleep(5) + + # Now release all clients to close their connections + print("Releasing all clients to close connections...") + release_event.set() + + # Wait for all tasks to complete + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check for exceptions + exceptions = [ + (i + 1, r) for i, r in enumerate(results) if isinstance(r, Exception) + ] + if exceptions: + failure_msg = "\n".join([f"Task {tid}: {exc}" for tid, exc in exceptions]) + self.fail(f"{len(exceptions)}/{NUM_TASKS} tasks failed:\n{failure_msg}") + + await asyncio.sleep(2) # wait a moment for all output to be printed print( f"\nāœ“ All {NUM_TASKS} clients completed successfully with correct results!" ) + # Each client already verified their own data within their task, + # so we've confirmed that all operations succeeded and persisted correctly + # Wait a bit for async hint generation to complete print("\nWaiting for hint generation to complete...") await asyncio.sleep(5) diff --git a/kai_mcp_solution_server/tools/deploy/podman-compose.yml b/kai_mcp_solution_server/tools/deploy/podman-compose.yml new file mode 100644 index 00000000..d0e11c38 --- /dev/null +++ b/kai_mcp_solution_server/tools/deploy/podman-compose.yml @@ -0,0 +1,49 @@ +version: "3.8" + +services: + postgres: + image: docker.io/postgres:16 + container_name: kai-postgres + environment: + POSTGRES_USER: kai_user + POSTGRES_PASSWORD: kai_password + POSTGRES_DB: kai_db + ports: + - 5432:5432 + healthcheck: + test: [CMD-SHELL, pg_isready -U kai_user] + interval: 5s + timeout: 5s + retries: 5 + volumes: + - kai-postgres-data:/var/lib/postgresql/data + - /dev/shm:/dev/shm # Better performance with shared memory + + kai-mcp-server: + image: ${IMAGE:-kai-mcp-solution-server:latest} + container_name: kai-mcp-server + depends_on: + postgres: + condition: service_healthy + ports: + - 8000:8000 + environment: + 
KAI_DB_DSN: postgresql+asyncpg://kai_user:kai_password@postgres:5432/kai_db # trunk-ignore(checkov/CKV_SECRET_4) + KAI_LLM_PARAMS: ${KAI_LLM_PARAMS} + MOUNT_PATH: ${MOUNT_PATH:-/} + # Pass through API keys and credentials + OPENAI_API_KEY: ${OPENAI_API_KEY} + ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY} + AZURE_OPENAI_API_KEY: ${AZURE_OPENAI_API_KEY} + AZURE_OPENAI_ENDPOINT: ${AZURE_OPENAI_ENDPOINT} + GOOGLE_API_KEY: ${GOOGLE_API_KEY} + AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} + AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} + AWS_SESSION_TOKEN: ${AWS_SESSION_TOKEN} + AWS_REGION: ${AWS_REGION} + OLLAMA_HOST: ${OLLAMA_HOST} + # Pass through any other env vars starting with KAI_ + KAI_API_KEY: ${KAI_API_KEY} + +volumes: + kai-postgres-data: {}