@@ -804,6 +804,7 @@ pub enum AcpThreadEvent {
804804 Error ,
805805 LoadError ( LoadError ) ,
806806 PromptCapabilitiesUpdated ,
807+ Refusal ,
807808}
808809
809810impl EventEmitter < AcpThreadEvent > for AcpThread { }
@@ -1569,15 +1570,42 @@ impl AcpThread {
15691570 this. send_task . take ( ) ;
15701571 }
15711572
1572- // Truncate entries if the last prompt was refused.
1573+ // Handle refusal - distinguish between user prompt and tool call refusals
15731574 if let Ok ( Ok ( acp:: PromptResponse {
15741575 stop_reason : acp:: StopReason :: Refusal ,
15751576 } ) ) = result
1576- && let Some ( ( ix, _) ) = this. last_user_message ( )
15771577 {
1578- let range = ix..this. entries . len ( ) ;
1579- this. entries . truncate ( ix) ;
1580- cx. emit ( AcpThreadEvent :: EntriesRemoved ( range) ) ;
1578+ if let Some ( ( user_msg_ix, _) ) = this. last_user_message ( ) {
1579+ // Check if there's a completed tool call with results after the last user message
1580+ // This indicates the refusal is in response to tool output, not the user's prompt
1581+ let has_completed_tool_call_after_user_msg =
1582+ this. entries . iter ( ) . skip ( user_msg_ix + 1 ) . any ( |entry| {
1583+ if let AgentThreadEntry :: ToolCall ( tool_call) = entry {
1584+ // Check if the tool call has completed and has output
1585+ matches ! ( tool_call. status, ToolCallStatus :: Completed )
1586+ && tool_call. raw_output . is_some ( )
1587+ } else {
1588+ false
1589+ }
1590+ } ) ;
1591+
1592+ if has_completed_tool_call_after_user_msg {
1593+ // Refusal is due to tool output - don't truncate, just notify
1594+ // The model refused based on what the tool returned
1595+ cx. emit ( AcpThreadEvent :: Refusal ) ;
1596+ } else {
1597+ // User prompt was refused - truncate back to before the user message
1598+ let range = user_msg_ix..this. entries . len ( ) ;
1599+ if range. start < range. end {
1600+ this. entries . truncate ( user_msg_ix) ;
1601+ cx. emit ( AcpThreadEvent :: EntriesRemoved ( range) ) ;
1602+ }
1603+ cx. emit ( AcpThreadEvent :: Refusal ) ;
1604+ }
1605+ } else {
1606+ // No user message found, treat as general refusal
1607+ cx. emit ( AcpThreadEvent :: Refusal ) ;
1608+ }
15811609 }
15821610
15831611 cx. emit ( AcpThreadEvent :: Stopped ) ;
@@ -2681,6 +2709,187 @@ mod tests {
26812709 assert_eq ! ( fs. files( ) , vec![ Path :: new( path!( "/test/file-0" ) ) ] ) ;
26822710 }
26832711
2712+ #[ gpui:: test]
2713+ async fn test_tool_result_refusal ( cx : & mut TestAppContext ) {
2714+ use std:: sync:: atomic:: AtomicUsize ;
2715+ init_test ( cx) ;
2716+
2717+ let fs = FakeFs :: new ( cx. executor ( ) ) ;
2718+ let project = Project :: test ( fs, None , cx) . await ;
2719+
2720+ // Create a connection that simulates refusal after tool result
2721+ let prompt_count = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
2722+ let connection = Rc :: new ( FakeAgentConnection :: new ( ) . on_user_message ( {
2723+ let prompt_count = prompt_count. clone ( ) ;
2724+ move |_request, thread, mut cx| {
2725+ let count = prompt_count. fetch_add ( 1 , SeqCst ) ;
2726+ async move {
2727+ if count == 0 {
2728+ // First prompt: Generate a tool call with result
2729+ thread. update ( & mut cx, |thread, cx| {
2730+ thread
2731+ . handle_session_update (
2732+ acp:: SessionUpdate :: ToolCall ( acp:: ToolCall {
2733+ id : acp:: ToolCallId ( "tool1" . into ( ) ) ,
2734+ title : "Test Tool" . into ( ) ,
2735+ kind : acp:: ToolKind :: Fetch ,
2736+ status : acp:: ToolCallStatus :: Completed ,
2737+ content : vec ! [ ] ,
2738+ locations : vec ! [ ] ,
2739+ raw_input : Some ( serde_json:: json!( { "query" : "test" } ) ) ,
2740+ raw_output : Some (
2741+ serde_json:: json!( { "result" : "inappropriate content" } ) ,
2742+ ) ,
2743+ } ) ,
2744+ cx,
2745+ )
2746+ . unwrap ( ) ;
2747+ } ) ?;
2748+
2749+ // Now return refusal because of the tool result
2750+ Ok ( acp:: PromptResponse {
2751+ stop_reason : acp:: StopReason :: Refusal ,
2752+ } )
2753+ } else {
2754+ Ok ( acp:: PromptResponse {
2755+ stop_reason : acp:: StopReason :: EndTurn ,
2756+ } )
2757+ }
2758+ }
2759+ . boxed_local ( )
2760+ }
2761+ } ) ) ;
2762+
2763+ let thread = cx
2764+ . update ( |cx| connection. new_thread ( project, Path :: new ( "/test" ) , cx) )
2765+ . await
2766+ . unwrap ( ) ;
2767+
2768+ // Track if we see a Refusal event
2769+ let saw_refusal_event = Arc :: new ( std:: sync:: Mutex :: new ( false ) ) ;
2770+ let saw_refusal_event_captured = saw_refusal_event. clone ( ) ;
2771+ thread. update ( cx, |_thread, cx| {
2772+ cx. subscribe (
2773+ & thread,
2774+ move |_thread, _event_thread, event : & AcpThreadEvent , _cx| {
2775+ if matches ! ( event, AcpThreadEvent :: Refusal ) {
2776+ * saw_refusal_event_captured. lock ( ) . unwrap ( ) = true ;
2777+ }
2778+ } ,
2779+ )
2780+ . detach ( ) ;
2781+ } ) ;
2782+
2783+ // Send a user message - this will trigger tool call and then refusal
2784+ let send_task = thread. update ( cx, |thread, cx| {
2785+ thread. send (
2786+ vec ! [ acp:: ContentBlock :: Text ( acp:: TextContent {
2787+ text: "Hello" . into( ) ,
2788+ annotations: None ,
2789+ } ) ] ,
2790+ cx,
2791+ )
2792+ } ) ;
2793+ cx. background_executor . spawn ( send_task) . detach ( ) ;
2794+ cx. run_until_parked ( ) ;
2795+
2796+ // Verify that:
2797+ // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2798+ // 2. The user message was NOT truncated
2799+ assert ! (
2800+ * saw_refusal_event. lock( ) . unwrap( ) ,
2801+ "Refusal event should be emitted for tool result refusals"
2802+ ) ;
2803+
2804+ thread. read_with ( cx, |thread, _| {
2805+ let entries = thread. entries ( ) ;
2806+ assert ! ( entries. len( ) >= 2 , "Should have user message and tool call" ) ;
2807+
2808+ // Verify user message is still there
2809+ assert ! (
2810+ matches!( entries[ 0 ] , AgentThreadEntry :: UserMessage ( _) ) ,
2811+ "User message should not be truncated"
2812+ ) ;
2813+
2814+ // Verify tool call is there with result
2815+ if let AgentThreadEntry :: ToolCall ( tool_call) = & entries[ 1 ] {
2816+ assert ! (
2817+ tool_call. raw_output. is_some( ) ,
2818+ "Tool call should have output"
2819+ ) ;
2820+ } else {
2821+ panic ! ( "Expected tool call at index 1" ) ;
2822+ }
2823+ } ) ;
2824+ }
2825+
2826+ #[ gpui:: test]
2827+ async fn test_user_prompt_refusal_emits_event ( cx : & mut TestAppContext ) {
2828+ init_test ( cx) ;
2829+
2830+ let fs = FakeFs :: new ( cx. executor ( ) ) ;
2831+ let project = Project :: test ( fs, None , cx) . await ;
2832+
2833+ let refuse_next = Arc :: new ( AtomicBool :: new ( false ) ) ;
2834+ let connection = Rc :: new ( FakeAgentConnection :: new ( ) . on_user_message ( {
2835+ let refuse_next = refuse_next. clone ( ) ;
2836+ move |_request, _thread, _cx| {
2837+ if refuse_next. load ( SeqCst ) {
2838+ async move {
2839+ Ok ( acp:: PromptResponse {
2840+ stop_reason : acp:: StopReason :: Refusal ,
2841+ } )
2842+ }
2843+ . boxed_local ( )
2844+ } else {
2845+ async move {
2846+ Ok ( acp:: PromptResponse {
2847+ stop_reason : acp:: StopReason :: EndTurn ,
2848+ } )
2849+ }
2850+ . boxed_local ( )
2851+ }
2852+ }
2853+ } ) ) ;
2854+
2855+ let thread = cx
2856+ . update ( |cx| connection. new_thread ( project, Path :: new ( path ! ( "/test" ) ) , cx) )
2857+ . await
2858+ . unwrap ( ) ;
2859+
2860+ // Track if we see a Refusal event
2861+ let saw_refusal_event = Arc :: new ( std:: sync:: Mutex :: new ( false ) ) ;
2862+ let saw_refusal_event_captured = saw_refusal_event. clone ( ) ;
2863+ thread. update ( cx, |_thread, cx| {
2864+ cx. subscribe (
2865+ & thread,
2866+ move |_thread, _event_thread, event : & AcpThreadEvent , _cx| {
2867+ if matches ! ( event, AcpThreadEvent :: Refusal ) {
2868+ * saw_refusal_event_captured. lock ( ) . unwrap ( ) = true ;
2869+ }
2870+ } ,
2871+ )
2872+ . detach ( ) ;
2873+ } ) ;
2874+
2875+ // Send a message that will be refused
2876+ refuse_next. store ( true , SeqCst ) ;
2877+ cx. update ( |cx| thread. update ( cx, |thread, cx| thread. send ( vec ! [ "hello" . into( ) ] , cx) ) )
2878+ . await
2879+ . unwrap ( ) ;
2880+
2881+ // Verify that a Refusal event WAS emitted for user prompt refusal
2882+ assert ! (
2883+ * saw_refusal_event. lock( ) . unwrap( ) ,
2884+ "Refusal event should be emitted for user prompt refusals"
2885+ ) ;
2886+
2887+ // Verify the message was truncated (user prompt refusal)
2888+ thread. read_with ( cx, |thread, cx| {
2889+ assert_eq ! ( thread. to_markdown( cx) , "" ) ;
2890+ } ) ;
2891+ }
2892+
26842893 #[ gpui:: test]
26852894 async fn test_refusal ( cx : & mut TestAppContext ) {
26862895 init_test ( cx) ;
@@ -2744,8 +2953,8 @@ mod tests {
27442953 ) ;
27452954 } ) ;
27462955
2747- // Simulate refusing the second message, ensuring the conversation gets
2748- // truncated to before sending it .
2956+ // Simulate refusing the second message. The message should be truncated
2957+ // when a user prompt is refused .
27492958 refuse_next. store ( true , SeqCst ) ;
27502959 cx. update ( |cx| thread. update ( cx, |thread, cx| thread. send ( vec ! [ "world" . into( ) ] , cx) ) )
27512960 . await
0 commit comments