Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/stress-test-mcp-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
KAI_DB_DSN: "postgresql+asyncpg://kai_user:kai_password@localhost:5432/kai_test_db"
KAI_LLM_PARAMS: '{"model": "fake", "responses": ["Test response"]}'
MCP_SERVER_URL: "http://localhost:8000"
NUM_CONCURRENT_CLIENTS: ${{ github.event.inputs.num_clients || '100' }}
NUM_CONCURRENT_CLIENTS: ${{ github.event.inputs.num_clients || '200' }}
run: |
echo "Starting MCP server connected to PostgreSQL..."
uv run python -m kai_mcp_solution_server --transport streamable-http --host 0.0.0.0 --port 8000 &
Expand Down
1 change: 1 addition & 0 deletions kai_mcp_solution_server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ run-podman: build
podman-postgres: build
@echo "Starting MCP solution server with PostgreSQL using podman-compose..."
@if [ -z "$(KAI_LLM_PARAMS)" ]; then echo "Error: KAI_LLM_PARAMS is required"; exit 1; fi
@cd tools/deploy && \
IMAGE=$(IMAGE) KAI_LLM_PARAMS='$(KAI_LLM_PARAMS)' MOUNT_PATH='$(MOUNT_PATH)' \
podman-compose up --force-recreate

Expand Down
20 changes: 19 additions & 1 deletion kai_mcp_solution_server/src/kai_mcp_solution_server/db/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
String,
event,
func,
text,
)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
Expand Down Expand Up @@ -110,6 +111,22 @@ async def ensure_tables_exist(engine: AsyncEngine) -> None:
await conn.run_sync(Base.metadata.create_all)


async def kill_idle_connections(engine: AsyncEngine) -> None:
"""Kill all idle connections from this application to the database."""
async with engine.begin() as conn:
await conn.execute(
text(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE application_name = 'kai-solution-server'
AND state = 'idle'
AND pid != pg_backend_pid()
"""
)
)

Comment on lines 114 to 136
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard Postgres-only termination and return affected count

Kill query is Postgres-specific; on SQLite/MySQL it will fail. Also, returning how many sessions were terminated is useful for observability.

-async def kill_idle_connections(engine: AsyncEngine) -> None:
-    """Kill all idle connections from this application to the database."""
-    async with engine.begin() as conn:
-        await conn.execute(
-            text(
-                """
-                SELECT pg_terminate_backend(pid)
-                FROM pg_stat_activity
-                WHERE application_name = 'kai-solution-server'
-                AND state = 'idle'
-                AND pid != pg_backend_pid()
-                """
-            )
-        )
+async def kill_idle_connections(engine: AsyncEngine) -> int:
+    """Kill all idle connections from this application to the database (Postgres only). Returns number terminated."""
+    if getattr(engine, "dialect", None) is None or engine.dialect.name != "postgresql":
+        return 0
+    async with engine.begin() as conn:
+        res = await conn.execute(
+            text(
+                """
+                SELECT pg_terminate_backend(pid)
+                FROM pg_stat_activity
+                WHERE application_name = 'kai-solution-server'
+                  AND state = 'idle'
+                  AND pid != pg_backend_pid()
+                """
+            )
+        )
+        # rowcount reflects number of rows returned/affected
+        return getattr(res, "rowcount", 0) or 0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def kill_idle_connections(engine: AsyncEngine) -> None:
"""Kill all idle connections from this application to the database."""
async with engine.begin() as conn:
await conn.execute(
text(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE application_name = 'kai-solution-server'
AND state = 'idle'
AND pid != pg_backend_pid()
"""
)
)
async def kill_idle_connections(engine: AsyncEngine) -> int:
"""Kill all idle connections from this application to the database (Postgres only). Returns number terminated."""
if getattr(engine, "dialect", None) is None or engine.dialect.name != "postgresql":
return 0
async with engine.begin() as conn:
res = await conn.execute(
text(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE application_name = 'kai-solution-server'
AND state = 'idle'
AND pid != pg_backend_pid()
"""
)
)
# rowcount reflects number of rows returned/affected
return getattr(res, "rowcount", 0) or 0


async def get_async_engine(url: URL | str) -> AsyncEngine:
# Convert to string if URL object
url_str = str(url)
Expand All @@ -133,10 +150,11 @@ async def get_async_engine(url: URL | str) -> AsyncEngine:
url,
pool_size=20, # Base connections maintained in pool
max_overflow=80, # Additional connections created as needed (total max = 100)
pool_timeout=30, # Timeout waiting for a connection from pool
pool_timeout=60, # Timeout waiting for a connection from pool
pool_recycle=3600, # Recycle connections after 1 hour
pool_pre_ping=True, # Test connections before using
echo_pool=False, # Set to True for debugging connection pool
pool_reset_on_return="rollback", # Reset connections on return to pool
)

@event.listens_for(engine.sync_engine, "connect")
Expand Down
180 changes: 153 additions & 27 deletions kai_mcp_solution_server/src/kai_mcp_solution_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from pydantic import BaseModel, model_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
from sqlalchemy import URL, and_, make_url, or_, select
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.exc import DBAPIError, IntegrityError, OperationalError
from sqlalchemy.exc import TimeoutError as SQLAlchemyTimeoutError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker

from kai_mcp_solution_server.analyzer_types import ExtendedIncident
from kai_mcp_solution_server.ast_diff.parser import Language, extract_ast_info
Expand All @@ -26,6 +28,7 @@
DBSolution,
DBViolation,
get_async_engine,
kill_idle_connections,
)
from kai_mcp_solution_server.db.python_objects import (
SolutionFile,
Expand All @@ -36,6 +39,54 @@
)


def with_db_recovery(func):
"""Decorator to execute database operations with automatic recovery on connection errors.
Uses a semaphore to limit concurrent DB operations and prevent pool exhaustion.
Implements exponential backoff with retry on connection errors.
"""

async def wrapper(kai_ctx: KaiSolutionServerContext, *args, **kwargs):
if _SharedResources.db_semaphore is None:
raise RuntimeError("Database semaphore not initialized")

# Semaphore ensures we don't overwhelm the connection pool
async with _SharedResources.db_semaphore:
max_retries = 3
base_delay = 0.1 # 100ms base delay

for attempt in range(max_retries):
try:
return await func(kai_ctx, *args, **kwargs)
except IntegrityError:
raise
except SQLAlchemyTimeoutError as e:

Check failure on line 63 in kai_mcp_solution_server/src/kai_mcp_solution_server/server.py

View workflow job for this annotation

GitHub Actions / Trunk Check

ruff(F841)

[new] Local variable `e` is assigned to but never used
if attempt < max_retries - 1:
delay = base_delay * (2**attempt) # Exponential backoff
log(
f"Connection pool timeout (attempt {attempt + 1}), retrying in {delay}s..."
)
await asyncio.sleep(delay)
await _recover_from_db_error()
else:
log(f"Connection pool exhausted after {max_retries} attempts")
raise
except (DBAPIError, OperationalError) as e:
if attempt < max_retries - 1:
delay = base_delay * (2**attempt)
log(
f"Database error (attempt {attempt + 1}): {e}, retrying in {delay}s..."
)
await asyncio.sleep(delay)
await _recover_from_db_error()
await kai_ctx.create()
else:
log(f"Database error after {max_retries} attempts, giving up")
raise

return wrapper


class SolutionServerSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="kai_")

Expand Down Expand Up @@ -108,37 +159,109 @@
return data


class KaiSolutionServerContext:
def __init__(self, settings: SolutionServerSettings) -> None:
self.settings = settings
self.lock = asyncio.Lock()
class _SharedResources:
"""Global shared resources initialized once at module level."""

engine: AsyncEngine | None = None
session_maker: async_sessionmaker | None = None
model: BaseChatModel | None = None
initialization_lock: asyncio.Lock | None = None
initialized: bool = False
# Semaphore to limit concurrent DB operations (prevent connection pool exhaustion)
db_semaphore: asyncio.Semaphore | None = None
max_concurrent_ops: int = 80 # Allow up to 80 concurrent DB operations


async def _initialize_shared_resources() -> None:
"""Initialize shared resources once, protected by a lock."""
if _SharedResources.initialization_lock is None:
_SharedResources.initialization_lock = asyncio.Lock()

async with _SharedResources.initialization_lock:
if _SharedResources.initialized:
return

async def create(self) -> None:
from kai_mcp_solution_server.db.dao import Base

self.engine = await get_async_engine(self.settings.db_dsn)
log(
"Initializing shared database engine and model (once for all connections)..."
)
settings = SolutionServerSettings()

# Initialize semaphore to limit concurrent DB operations
_SharedResources.db_semaphore = asyncio.Semaphore(
_SharedResources.max_concurrent_ops
)
log(
f"DB operation semaphore initialized with limit: {_SharedResources.max_concurrent_ops}"
)

# Ensure tables exist (safe - only creates if not already there)
async with self.engine.begin() as conn:
# Initialize database engine
_SharedResources.engine = await get_async_engine(settings.db_dsn)
async with _SharedResources.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

self.session_maker = async_sessionmaker(
bind=self.engine, expire_on_commit=False
_SharedResources.session_maker = async_sessionmaker(
bind=_SharedResources.engine, expire_on_commit=False
)
self.model: BaseChatModel

if self.settings.llm_params is None:
# Initialize model
if settings.llm_params is None:
raise ValueError("LLM parameters must be provided in the settings.")
elif self.settings.llm_params.get("model") == "fake":
llm_params = self.settings.llm_params.copy()
elif settings.llm_params.get("model") == "fake":
llm_params = settings.llm_params.copy()
llm_params.pop("model", None)
if "responses" not in llm_params:
llm_params["responses"] = [
"fake response",
]
self.model = FakeListChatModel(**llm_params)
llm_params["responses"] = ["fake response"]
_SharedResources.model = FakeListChatModel(**llm_params)
else:
self.model = init_chat_model(**self.settings.llm_params)
_SharedResources.model = init_chat_model(**settings.llm_params)

_SharedResources.initialized = True
log("Shared resources initialized successfully")


async def _recover_from_db_error() -> None:
"""Recover from database errors by killing idle connections or recreating engine."""
if _SharedResources.engine is not None:
log("Recovering from database error - killing idle connections...")
try:
await kill_idle_connections(_SharedResources.engine)
log("Successfully killed idle connections")
except Exception as e:
log(f"Failed to kill idle connections: {e}")
log("Disposing and recreating engine...")
await _SharedResources.engine.dispose()
_SharedResources.initialized = False
await _initialize_shared_resources()


class KaiSolutionServerContext:
"""Per-connection context that references shared resources."""

def __init__(self, settings: SolutionServerSettings) -> None:
self.settings = settings
self.lock = asyncio.Lock()
# References to shared resources (set in create())
self.engine: AsyncEngine | None = None
self.session_maker: async_sessionmaker | None = None
self.model: BaseChatModel | None = None

async def create(self) -> None:
"""Initialize shared resources if needed and reference them."""
await _initialize_shared_resources()

if _SharedResources.engine is None:
raise RuntimeError("Database engine failed to initialize")
if _SharedResources.session_maker is None:
raise RuntimeError("Session maker failed to initialize")
if _SharedResources.model is None:
raise RuntimeError("Model failed to initialize")

log(f"Connection using shared engine: {id(_SharedResources.engine)}")
self.engine = _SharedResources.engine
self.session_maker = _SharedResources.session_maker
self.model = _SharedResources.model


@asynccontextmanager
Expand All @@ -157,15 +280,8 @@

yield ctx
except Exception as e:

log(f"Error in lifespan: {traceback.format_exc()}")
raise e
finally:
# Clean up database connections when client disconnects
if "ctx" in locals() and hasattr(ctx, "engine"):
log("Disposing database engine...")
await ctx.engine.dispose()
log("Database engine disposed")


mcp: FastMCP[KaiSolutionServerContext] = FastMCP(
Expand All @@ -178,6 +294,7 @@
solution_id: int


@with_db_recovery
async def create_incident(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -265,6 +382,7 @@
return results


@with_db_recovery
async def create_solution(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -394,6 +512,7 @@
)


@with_db_recovery
async def generate_hint_v1(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -458,6 +577,7 @@
await session.flush()


@with_db_recovery
async def generate_hint_v2(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -536,6 +656,7 @@
await session.flush()


@with_db_recovery
async def generate_hint_v3(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -626,6 +747,7 @@
await session.flush()


@with_db_recovery
async def delete_solution(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -665,6 +787,7 @@
hint_id: int


@with_db_recovery
async def get_best_hint(
kai_ctx: KaiSolutionServerContext,
ruleset_name: str,
Expand Down Expand Up @@ -727,6 +850,7 @@
unknown_solutions: int = 0


@with_db_recovery
async def get_success_rate(
kai_ctx: KaiSolutionServerContext,
violation_ids: list[ViolationID],
Expand Down Expand Up @@ -809,6 +933,7 @@
)


@with_db_recovery
async def accept_file(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down Expand Up @@ -897,6 +1022,7 @@
)


@with_db_recovery
async def reject_file(
kai_ctx: KaiSolutionServerContext,
client_id: str,
Expand Down
Loading
Loading