Merged
17 changes: 14 additions & 3 deletions kai_mcp_solution_server/tests/mcp_loader_script.py
@@ -3,6 +3,7 @@
import os
import ssl
import sys
import traceback
from asyncio.log import logger
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncIterator
@@ -70,9 +71,19 @@ async def create_http_client(args: MCPClientArgs) -> AsyncIterator[ClientSession
yield session

except Exception as e:

# def recursive_print_exception(exc: Exception, level: int = 0) -> None:
# indent = " " * level
# print(f"{indent}- {str(exc)}")
# if isinstance(exc, ExceptionGroup):
# for sub_exc in exc.exceptions:
# recursive_print_exception(sub_exc, level + 1)

logger.error("HTTP transport error: %s", str(e), exc_info=True)
print(f"x Error with HTTP transport: {e}")
print(f"! Make sure the server is running at {server_url}")
print(f"x Error with HTTP transport: {traceback.format_exc()}")
# if isinstance(e, ExceptionGroup):
# for sub_exc in e.exceptions:
# recursive_print_exception(sub_exc)

# Add specific advice for SSL certificate errors
if (
@@ -87,7 +98,7 @@ async def create_http_client(args: MCPClientArgs) -> AsyncIterator[ClientSession
print(" 2. Use a valid SSL certificate on the server")
print(" 3. Add the server's certificate to your trusted CA store")

print("! Try using the STDIO transport instead with --transport stdio")
print(f"! Make sure the server is running at {server_url}")


@asynccontextmanager
186 changes: 186 additions & 0 deletions kai_mcp_solution_server/tests/test_multiple_integration.py
@@ -1,17 +1,23 @@
import asyncio
import concurrent
import concurrent.futures
import datetime
import json
import os
import subprocess

Check notice on line 7 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): bandit(B404) [new] Consider possible security implications associated with the subprocess module.
import threading
import unittest
from uuid import uuid4

from fastmcp import Client
from mcp import ClientSession
from mcp.types import CallToolResult

from kai_mcp_solution_server.analyzer_types import ExtendedIncident
from kai_mcp_solution_server.server import GetBestHintResult, SuccessRateMetric
from tests.mcp_client import MCPClientArgs
from tests.mcp_loader_script import create_client
from tests.ssl_utils import apply_ssl_bypass

# TODO: The tracebacks from these tests contain horrible, impossible-to-parse output.

@@ -353,3 +359,183 @@
best_hint = GetBestHintResult(**json.loads(get_best_hint.content[0].text))
print(f"Best hint for {RULESET_NAME_A}/{VIOLATION_NAME_A}: {best_hint}")
self.assertEqual(best_hint.hint, llm_params["responses"][0])

@unittest.skip("Skipping test_solution_server_2 for now")
async def test_solution_server_2(self) -> None:
llm_params = {
"model": "fake",
"responses": [
f"{uuid4()} You should add a smiley face to the file.",
],
}
os.environ["KAI_LLM_PARAMS"] = json.dumps(llm_params)

async with create_client(self.mcp_args) as session:
await session.initialize()

RULESET_NAME_A = f"ruleset-{uuid4()}"
VIOLATION_NAME_A = f"violation-{uuid4()}"
CLIENT_ID_A = str(uuid4())

print()
print("--- Testing modify ---")

create_incident_a = await self.call_tool(
session,
"create_incident",
{
"client_id": CLIENT_ID_A,
"extended_incident": ExtendedIncident(
uri="file://src/file_to_smile.txt",
message="this file needs to have a smiley face",
ruleset_name=RULESET_NAME_A,
violation_name=VIOLATION_NAME_A,
).model_dump(),
},
)
INCIDENT_ID_A = int(create_incident_a.model_dump()["content"][0]["text"])

create_solution_for_incident_a = await self.call_tool(
session,
"create_solution",
{
"client_id": CLIENT_ID_A,
"incident_ids": [INCIDENT_ID_A],
"before": [
{
"uri": "file://src/file_to_smile.txt",
"content": "I am very frowny :(",
}
],
"after": [
{
"uri": "file://src/file_to_smile.txt",
"content": "I am very smiley :)",
}
],
"reasoning": None,
"used_hint_ids": None,
},
)
SOLUTION_FOR_INCIDENT_A_ID = int(

Check failure on line 420 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): ruff(F841) [new] Local variable `SOLUTION_FOR_INCIDENT_A_ID` is assigned to but never used
create_solution_for_incident_a.model_dump()["content"][0]["text"]
)

async def test_multiple_users(self) -> None:
multiple_user_mcp_args = MCPClientArgs(

Check failure on line 425 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): ruff(F841) [new] Local variable `multiple_user_mcp_args` is assigned to but never used
transport="http",
host="localhost",
port=8087,
insecure=True,
server_path=self.mcp_args.server_path,
)

os.environ["KAI_LLM_PARAMS"] = json.dumps(
{
"model": "fake",
"responses": [
f"You should add a smiley face to the file.",

Check failure on line 437 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): ruff(F541) [new] f-string without any placeholders
],
}
)

def stream_output(process: subprocess.Popen) -> None:
try:
assert process.stdout is not None
for line in iter(process.stdout.readline, b""):
print(f"[Server] {line.decode().rstrip()}")
except Exception as e:
print(f"Error while streaming output: {e}")
finally:
process.stdout.close()

def poll_process(process: subprocess.Popen) -> None:
# Check if the process has exited early
if process.poll() is not None:
output = process.stdout.read() if process.stdout else b""
raise RuntimeError(
f"HTTP server process exited prematurely. Output: {output.decode(errors='replace')}"
)

def run_async_in_thread(fn, *args, **kwargs):
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
result = loop.run_until_complete(fn(*args, **kwargs))
return result
finally:
loop.close()
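As an aside on the helper above: since each worker thread runs exactly one coroutine to completion, asyncio.run would achieve the same with less event-loop bookkeeping. A sketch of that alternative (not what this test uses):

def run_async_in_thread(fn, *args, **kwargs):
    # asyncio.run creates a fresh event loop for this thread, runs the
    # coroutine to completion, and closes the loop afterwards.
    return asyncio.run(fn(*args, **kwargs))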

async def client_task(client_id: str) -> None:
print(f"[Client {client_id}] starting")
ssl_patch = apply_ssl_bypass()

Check failure on line 475 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): ruff(F841) [new] Local variable `ssl_patch` is assigned to but never used

client = Client(
transport="http://localhost:8087",
)

async with client:
await client.session.initialize()
print(f"[Client {client_id}] initialized")

await client.session.list_tools()
print(f"[Client {client_id}] listed tools")

print(f"[Client {client_id}] finished")

try:
self.http_server_process = subprocess.Popen(

Check notice on line 491 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): bandit(B607) [new] Starting a process with a partial executable path
Check notice on line 491 in kai_mcp_solution_server/tests/test_multiple_integration.py (GitHub Actions / Trunk Check): bandit(B603) [new] subprocess call - check for execution of untrusted input.
[
"python",
"-m",
"kai_mcp_solution_server",
"--transport",
"streamable-http",
"--port",
"8087",
],

🛠️ Refactor suggestion

Use sys.executable instead of bare "python" (Bandit B607/B603).

More robust across environments and virtualenvs.

-            self.http_server_process = subprocess.Popen(
+            self.http_server_process = subprocess.Popen(
                 [
-                    "python",
+                    sys.executable,
                     "-m",
                     "kai_mcp_solution_server",
                     "--transport",
                     "streamable-http",
                     "--port",
                     "8087",
                 ],

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.12.2)
492-500: Starting a process with a partial executable path (S607)

🪛 GitHub Check: Trunk Check
[notice] 491-491: bandit(B607) [new] Starting a process with a partial executable path
[notice] 491-491: bandit(B603) [new] subprocess call - check for execution of untrusted input.

🤖 Prompt for AI Agents
In kai_mcp_solution_server/tests/test_multiple_integration.py around lines 491
to 500, the subprocess invocation uses the literal "python" which is brittle
across environments; replace the string with sys.executable and ensure the
module imports sys at the top of the test file so the current Python interpreter
(including virtualenvs) is used when launching the server process.
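Putting the suggestion together, the launch call would look roughly like this (a sketch; everything except sys.executable is taken from the PR's own arguments):

import subprocess
import sys

# Use the interpreter running the tests so virtualenvs are respected
# (addresses the bandit B607/B603 notices on the bare "python" string).
http_server_process = subprocess.Popen(
    [
        sys.executable,
        "-m",
        "kai_mcp_solution_server",
        "--transport",
        "streamable-http",
        "--port",
        "8087",
    ],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)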

stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)

stream_thread = threading.Thread(
target=stream_output, args=(self.http_server_process,)
)
stream_thread.daemon = True
stream_thread.start()

await asyncio.sleep(1) # give the server a second to start

NUM_TASKS = 1
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
# Submit each task to the thread pool and store the Future objects.
# The executor will call run_async_in_thread for each task ID.
futures = {
executor.submit(run_async_in_thread, client_task, i): i
for i in range(1, NUM_TASKS + 1)
}

# Use as_completed() to process results as they become available.
# This is non-blocking to the main thread while tasks are running.
for future in concurrent.futures.as_completed(futures):
task_id = futures[future]
try:
result = future.result()
print(
f"[Main] received result for Task {task_id}: {result}",
flush=True,
)
except Exception as exc:
print(f"[Main] Task {task_id} generated an exception: {exc}")

await asyncio.sleep(10) # wait a moment for all output to be printed

finally:
self.http_server_process.terminate()
self.http_server_process.wait()
print("Server process terminated.")
stream_thread.join()