Skip to content

Commit 6656f10

Browse files
committed
👻 Fixed hint generation bugs and added tests
Signed-off-by: JonahSussman <[email protected]>
1 parent 27fb9b4 commit 6656f10

File tree

8 files changed

+1369
-67
lines changed

8 files changed

+1369
-67
lines changed

kai_mcp_solution_server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies = [
1010
"asyncpg>=0.30.0",
1111
"fastmcp>=2.8.0",
1212
"langchain>=0.3.19",
13+
"langchain-community>=0.3.25",
1314
"langchain-openai>=0.3.3",
1415
"psycopg2-binary>=2.9.10",
1516
"pydantic>=2.10.6",

kai_mcp_solution_server/src/kai_mcp_solution_server/db/dao.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,18 +332,15 @@ class DBFile(Base):
332332
__tablename__ = "kai_files"
333333

334334
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
335+
created_at: Mapped[datetime] = mapped_column(
336+
DateTime(timezone=True),
337+
server_default=func.now(),
338+
nullable=False,
339+
init=False,
340+
)
335341
uri: Mapped[str]
336342
client_id: Mapped[str]
337343
content: Mapped[str]
338-
next_id: Mapped[int | None] = mapped_column(
339-
ForeignKey("kai_files.id", ondelete="SET NULL", onupdate="CASCADE"),
340-
init=False,
341-
)
342-
next: Mapped[DBFile | None] = relationship(
343-
"DBFile",
344-
uselist=False,
345-
lazy="selectin",
346-
)
347344
status: Mapped[SolutionStatus]
348345

349346
solution_before: Mapped[set[DBSolution]] = relationship(

kai_mcp_solution_server/src/kai_mcp_solution_server/server.py

Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fastmcp import Context, FastMCP
1111
from langchain.chat_models import init_chat_model
1212
from langchain.chat_models.base import BaseChatModel
13-
from langchain_core.language_models.fake_chat_models import FakeChatModel
13+
from langchain_community.chat_models.fake import FakeListChatModel
1414
from pydantic import BaseModel, model_validator
1515
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
1616
from sqlalchemy import URL, and_, make_url, or_, select
@@ -126,7 +126,9 @@ async def create(self) -> None:
126126
if self.settings.llm_params is None:
127127
raise ValueError("LLM parameters must be provided in the settings.")
128128
elif self.settings.llm_params.get("model") == "fake":
129-
self.model = FakeChatModel()
129+
llm_params = self.settings.llm_params.copy()
130+
llm_params.pop("model", None)
131+
self.model = FakeListChatModel(**llm_params)
130132
else:
131133
self.model = init_chat_model(**self.settings.llm_params)
132134

@@ -285,40 +287,28 @@ async def create_solution(
285287
# Try to match the before files. If you can't, create them.
286288
# Create the after files. Associate them with the before files.
287289
# Create the solution.
290+
288291
db_before_files: set[DBFile] = set()
289292
for file in before:
290-
stmt = select(DBFile).where(
291-
DBFile.client_id == client_id,
292-
DBFile.uri == file.uri,
293+
stmt = (
294+
select(DBFile)
295+
.where(
296+
DBFile.client_id == client_id,
297+
DBFile.uri == file.uri,
298+
)
299+
.order_by(DBFile.created_at.desc())
293300
)
294-
prev_before = (await session.execute(stmt)).scalar_one_or_none()
301+
prev_before = (await session.execute(stmt)).scalars().first()
295302

296-
if prev_before is None:
297-
next_before = DBFile(
298-
client_id=client_id,
299-
uri=file.uri,
300-
content=file.content,
301-
status=SolutionStatus.PENDING,
302-
solution_before=set(),
303-
solution_after=set(),
304-
next=None,
305-
)
306-
session.add(next_before)
307-
db_before_files.add(next_before)
308-
elif prev_before.content != file.content:
303+
if prev_before is None or prev_before.content != file.content:
309304
next_before = DBFile(
310305
client_id=client_id,
311306
uri=file.uri,
312307
content=file.content,
313308
status=SolutionStatus.PENDING,
314309
solution_before=set(),
315310
solution_after=set(),
316-
next=None,
317311
)
318-
319-
prev_before.status = SolutionStatus.PENDING
320-
prev_before.next = next_before
321-
322312
session.add(next_before)
323313
db_before_files.add(next_before)
324314
else:
@@ -334,18 +324,17 @@ async def create_solution(
334324
status=SolutionStatus.PENDING,
335325
solution_before=set(),
336326
solution_after=set(),
337-
next=None,
338327
)
339328

340-
stmt = select(DBFile).where(
341-
DBFile.client_id == client_id,
342-
DBFile.uri == file.uri,
329+
stmt = (
330+
select(DBFile)
331+
.where(
332+
DBFile.client_id == client_id,
333+
DBFile.uri == file.uri,
334+
)
335+
.order_by(DBFile.created_at.desc())
343336
)
344337

345-
previous_after = (await session.execute(stmt)).scalar_one_or_none()
346-
if previous_after is not None:
347-
previous_after.next = next_after
348-
349338
db_after_files.add(next_after)
350339
session.add(next_after)
351340

@@ -400,12 +389,15 @@ async def generate_hint_v1(
400389
async with kai_ctx.session_maker.begin() as session:
401390
solutions_stmt = select(DBSolution).where(
402391
DBSolution.client_id == client_id,
403-
DBSolution.solution_status == SolutionStatus.ACCEPTED,
392+
or_(
393+
DBSolution.solution_status == SolutionStatus.ACCEPTED,
394+
DBSolution.solution_status == SolutionStatus.MODIFIED,
395+
),
404396
)
405397
solutions = (await session.execute(solutions_stmt)).scalars().all()
406398
if len(solutions) == 0:
407399
log(
408-
f"No accepted solutions found for client {client_id}. No hint generated."
400+
f"No accepted or modified solutions found for client {client_id}. No hint generated."
409401
)
410402
return
411403

@@ -464,7 +456,10 @@ async def generate_hint_v2(
464456
async with kai_ctx.session_maker.begin() as session:
465457
solutions_stmt = select(DBSolution).where(
466458
DBSolution.client_id == client_id,
467-
DBSolution.solution_status == SolutionStatus.ACCEPTED,
459+
or_(
460+
DBSolution.solution_status == SolutionStatus.ACCEPTED,
461+
DBSolution.solution_status == SolutionStatus.MODIFIED,
462+
),
468463
)
469464
solutions = (await session.execute(solutions_stmt)).scalars().all()
470465
if len(solutions) == 0:
@@ -680,7 +675,9 @@ async def get_best_hint(
680675

681676
for hint in sorted(violation.hints, key=lambda h: h.created_at, reverse=True):
682677
if any(
683-
s.solution_status == SolutionStatus.ACCEPTED for s in hint.solutions
678+
s.solution_status == SolutionStatus.ACCEPTED
679+
or s.solution_status == SolutionStatus.MODIFIED
680+
for s in hint.solutions
684681
):
685682
return GetBestHintResult(
686683
hint=hint.text or "",
@@ -810,7 +807,8 @@ async def accept_file(
810807
solutions_stmt = select(DBSolution).where(DBSolution.client_id == client_id)
811808
solutions = (await session.execute(solutions_stmt)).scalars().all()
812809

813-
files_to_add: set[DBFile] = set()
810+
# Files to add or remove from the solution
811+
files_to_update: set[tuple[DBSolution, DBFile]] = set()
814812

815813
for solution in solutions:
816814
for file in solution.after:
@@ -821,21 +819,28 @@ async def accept_file(
821819
file.status = SolutionStatus.ACCEPTED
822820
continue
823821

824-
db_file = DBFile(
825-
client_id=client_id,
826-
uri=solution_file.uri,
827-
content=solution_file.content,
828-
status=SolutionStatus.MODIFIED,
829-
solution_before=set(),
830-
solution_after=set(),
831-
next=None,
832-
)
833-
files_to_add.add(db_file)
822+
files_to_update.add((solution, file))
823+
824+
if len(files_to_update) != 0:
825+
log(
826+
f"Updating {len(files_to_update)} files for client {client_id} with URI {solution_file.uri}",
827+
)
828+
new_file = DBFile(
829+
client_id=client_id,
830+
uri=solution_file.uri,
831+
content=solution_file.content,
832+
status=SolutionStatus.MODIFIED,
833+
solution_before=set(),
834+
solution_after=set(),
835+
)
836+
session.add(new_file)
837+
for solution, old_file in files_to_update:
838+
# Remove the old file from the solution
839+
solution.after.remove(old_file)
840+
solution.after.add(new_file)
834841

835-
# NOTE: Doing it this way to avoid modifying solutions.after while iterating
836-
for file in files_to_add:
837-
file.solution_after.add(solution)
838-
session.add(file)
842+
session.add(new_file)
843+
session.add(solution)
839844

840845
await session.flush()
841846

@@ -851,16 +856,17 @@ async def accept_file(
851856
session.add(solution)
852857
await session.flush()
853858

854-
# print(
855-
# f"Solution {solution.id} status: {solution.solution_status}",
856-
# file=sys.stderr,
857-
# )
859+
print(
860+
f"Solution {solution.id} status: {solution.solution_status}",
861+
file=sys.stderr,
862+
)
858863

859864
if not (
860865
solution.solution_status == SolutionStatus.ACCEPTED
861866
or solution.solution_status == SolutionStatus.MODIFIED
862867
):
863868
all_solutions_accepted_or_modified = False
869+
break
864870

865871
await session.commit()
866872

0 commit comments

Comments
 (0)