1010from fastmcp import Context , FastMCP
1111from langchain .chat_models import init_chat_model
1212from langchain .chat_models .base import BaseChatModel
13- from langchain_core . language_models . fake_chat_models import FakeChatModel
13+ from langchain_community . chat_models . fake import FakeListChatModel
1414from pydantic import BaseModel , model_validator
1515from pydantic_settings import BaseSettings , NoDecode , SettingsConfigDict
1616from sqlalchemy import URL , and_ , make_url , or_ , select
@@ -126,7 +126,9 @@ async def create(self) -> None:
126126 if self .settings .llm_params is None :
127127 raise ValueError ("LLM parameters must be provided in the settings." )
128128 elif self .settings .llm_params .get ("model" ) == "fake" :
129- self .model = FakeChatModel ()
129+ llm_params = self .settings .llm_params .copy ()
130+ llm_params .pop ("model" , None )
131+ self .model = FakeListChatModel (** llm_params )
130132 else :
131133 self .model = init_chat_model (** self .settings .llm_params )
132134
@@ -285,40 +287,28 @@ async def create_solution(
285287 # Try to match the before files. If you can't, create them.
286288 # Create the after files. Associate them with the before files.
287289 # Create the solution.
290+
288291 db_before_files : set [DBFile ] = set ()
289292 for file in before :
290- stmt = select (DBFile ).where (
291- DBFile .client_id == client_id ,
292- DBFile .uri == file .uri ,
293+ stmt = (
294+ select (DBFile )
295+ .where (
296+ DBFile .client_id == client_id ,
297+ DBFile .uri == file .uri ,
298+ )
299+ .order_by (DBFile .created_at .desc ())
293300 )
294- prev_before = (await session .execute (stmt )).scalar_one_or_none ()
301+ prev_before = (await session .execute (stmt )).scalars (). first ()
295302
296- if prev_before is None :
297- next_before = DBFile (
298- client_id = client_id ,
299- uri = file .uri ,
300- content = file .content ,
301- status = SolutionStatus .PENDING ,
302- solution_before = set (),
303- solution_after = set (),
304- next = None ,
305- )
306- session .add (next_before )
307- db_before_files .add (next_before )
308- elif prev_before .content != file .content :
303+ if prev_before is None or prev_before .content != file .content :
309304 next_before = DBFile (
310305 client_id = client_id ,
311306 uri = file .uri ,
312307 content = file .content ,
313308 status = SolutionStatus .PENDING ,
314309 solution_before = set (),
315310 solution_after = set (),
316- next = None ,
317311 )
318-
319- prev_before .status = SolutionStatus .PENDING
320- prev_before .next = next_before
321-
322312 session .add (next_before )
323313 db_before_files .add (next_before )
324314 else :
@@ -334,18 +324,17 @@ async def create_solution(
334324 status = SolutionStatus .PENDING ,
335325 solution_before = set (),
336326 solution_after = set (),
337- next = None ,
338327 )
339328
340- stmt = select (DBFile ).where (
341- DBFile .client_id == client_id ,
342- DBFile .uri == file .uri ,
329+ stmt = (
330+ select (DBFile )
331+ .where (
332+ DBFile .client_id == client_id ,
333+ DBFile .uri == file .uri ,
334+ )
335+ .order_by (DBFile .created_at .desc ())
343336 )
344337
345- previous_after = (await session .execute (stmt )).scalar_one_or_none ()
346- if previous_after is not None :
347- previous_after .next = next_after
348-
349338 db_after_files .add (next_after )
350339 session .add (next_after )
351340
@@ -400,12 +389,15 @@ async def generate_hint_v1(
400389 async with kai_ctx .session_maker .begin () as session :
401390 solutions_stmt = select (DBSolution ).where (
402391 DBSolution .client_id == client_id ,
403- DBSolution .solution_status == SolutionStatus .ACCEPTED ,
392+ or_ (
393+ DBSolution .solution_status == SolutionStatus .ACCEPTED ,
394+ DBSolution .solution_status == SolutionStatus .MODIFIED ,
395+ ),
404396 )
405397 solutions = (await session .execute (solutions_stmt )).scalars ().all ()
406398 if len (solutions ) == 0 :
407399 log (
408- f"No accepted solutions found for client { client_id } . No hint generated."
400+ f"No accepted or modified solutions found for client { client_id } . No hint generated."
409401 )
410402 return
411403
@@ -464,7 +456,10 @@ async def generate_hint_v2(
464456 async with kai_ctx .session_maker .begin () as session :
465457 solutions_stmt = select (DBSolution ).where (
466458 DBSolution .client_id == client_id ,
467- DBSolution .solution_status == SolutionStatus .ACCEPTED ,
459+ or_ (
460+ DBSolution .solution_status == SolutionStatus .ACCEPTED ,
461+ DBSolution .solution_status == SolutionStatus .MODIFIED ,
462+ ),
468463 )
469464 solutions = (await session .execute (solutions_stmt )).scalars ().all ()
470465 if len (solutions ) == 0 :
@@ -680,7 +675,9 @@ async def get_best_hint(
680675
681676 for hint in sorted (violation .hints , key = lambda h : h .created_at , reverse = True ):
682677 if any (
683- s .solution_status == SolutionStatus .ACCEPTED for s in hint .solutions
678+ s .solution_status == SolutionStatus .ACCEPTED
679+ or s .solution_status == SolutionStatus .MODIFIED
680+ for s in hint .solutions
684681 ):
685682 return GetBestHintResult (
686683 hint = hint .text or "" ,
@@ -810,7 +807,8 @@ async def accept_file(
810807 solutions_stmt = select (DBSolution ).where (DBSolution .client_id == client_id )
811808 solutions = (await session .execute (solutions_stmt )).scalars ().all ()
812809
813- files_to_add : set [DBFile ] = set ()
810+ # Files to add or remove from the solution
811+ files_to_update : set [tuple [DBSolution , DBFile ]] = set ()
814812
815813 for solution in solutions :
816814 for file in solution .after :
@@ -821,21 +819,28 @@ async def accept_file(
821819 file .status = SolutionStatus .ACCEPTED
822820 continue
823821
824- db_file = DBFile (
825- client_id = client_id ,
826- uri = solution_file .uri ,
827- content = solution_file .content ,
828- status = SolutionStatus .MODIFIED ,
829- solution_before = set (),
830- solution_after = set (),
831- next = None ,
832- )
833- files_to_add .add (db_file )
822+ files_to_update .add ((solution , file ))
823+
824+ if len (files_to_update ) != 0 :
825+ log (
826+ f"Updating { len (files_to_update )} files for client { client_id } with URI { solution_file .uri } " ,
827+ )
828+ new_file = DBFile (
829+ client_id = client_id ,
830+ uri = solution_file .uri ,
831+ content = solution_file .content ,
832+ status = SolutionStatus .MODIFIED ,
833+ solution_before = set (),
834+ solution_after = set (),
835+ )
836+ session .add (new_file )
837+ for solution , old_file in files_to_update :
838+ # Remove the old file from the solution
839+ solution .after .remove (old_file )
840+ solution .after .add (new_file )
834841
835- # NOTE: Doing it this way to avoid modifying solutions.after while iterating
836- for file in files_to_add :
837- file .solution_after .add (solution )
838- session .add (file )
842+ session .add (new_file )
843+ session .add (solution )
839844
840845 await session .flush ()
841846
@@ -851,16 +856,17 @@ async def accept_file(
851856 session .add (solution )
852857 await session .flush ()
853858
854- # print(
855- # f"Solution {solution.id} status: {solution.solution_status}",
856- # file=sys.stderr,
857- # )
859+ print (
860+ f"Solution { solution .id } status: { solution .solution_status } " ,
861+ file = sys .stderr ,
862+ )
858863
859864 if not (
860865 solution .solution_status == SolutionStatus .ACCEPTED
861866 or solution .solution_status == SolutionStatus .MODIFIED
862867 ):
863868 all_solutions_accepted_or_modified = False
869+ break
864870
865871 await session .commit ()
866872
0 commit comments