Skip to content

Commit d376483

Browse files
rtfeldmannotpeter
authored andcommitted
Handle model refusal in ACP threads (#37383)
If the model refuses a prompt, we now: * Show an error if it was a user prompt (and truncate it out of the history) * Respond with a failed tool call if the refusal was for a tool call <img width="607" height="260" alt="Screenshot 2025-09-02 at 5 11 45 PM" src="https://github.com/user-attachments/assets/070b5ee7-6ad6-4a63-8395-f9a5093cc40e" /> <img width="607" height="265" alt="Screenshot 2025-09-02 at 5 11 38 PM" src="https://github.com/user-attachments/assets/98862586-390b-494e-b1f8-71d8341c8d9d" /> Release Notes: - Improve handling of model refusals in ACP threads
1 parent 76e7c78 commit d376483

File tree

6 files changed

+387
-11
lines changed

6 files changed

+387
-11
lines changed

crates/acp_thread/src/acp_thread.rs

Lines changed: 216 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ pub enum AcpThreadEvent {
804804
Error,
805805
LoadError(LoadError),
806806
PromptCapabilitiesUpdated,
807+
Refusal,
807808
}
808809

809810
impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -1569,15 +1570,42 @@ impl AcpThread {
15691570
this.send_task.take();
15701571
}
15711572

1572-
// Truncate entries if the last prompt was refused.
1573+
// Handle refusal - distinguish between user prompt and tool call refusals
15731574
if let Ok(Ok(acp::PromptResponse {
15741575
stop_reason: acp::StopReason::Refusal,
15751576
})) = result
1576-
&& let Some((ix, _)) = this.last_user_message()
15771577
{
1578-
let range = ix..this.entries.len();
1579-
this.entries.truncate(ix);
1580-
cx.emit(AcpThreadEvent::EntriesRemoved(range));
1578+
if let Some((user_msg_ix, _)) = this.last_user_message() {
1579+
// Check if there's a completed tool call with results after the last user message
1580+
// This indicates the refusal is in response to tool output, not the user's prompt
1581+
let has_completed_tool_call_after_user_msg =
1582+
this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1583+
if let AgentThreadEntry::ToolCall(tool_call) = entry {
1584+
// Check if the tool call has completed and has output
1585+
matches!(tool_call.status, ToolCallStatus::Completed)
1586+
&& tool_call.raw_output.is_some()
1587+
} else {
1588+
false
1589+
}
1590+
});
1591+
1592+
if has_completed_tool_call_after_user_msg {
1593+
// Refusal is due to tool output - don't truncate, just notify
1594+
// The model refused based on what the tool returned
1595+
cx.emit(AcpThreadEvent::Refusal);
1596+
} else {
1597+
// User prompt was refused - truncate back to before the user message
1598+
let range = user_msg_ix..this.entries.len();
1599+
if range.start < range.end {
1600+
this.entries.truncate(user_msg_ix);
1601+
cx.emit(AcpThreadEvent::EntriesRemoved(range));
1602+
}
1603+
cx.emit(AcpThreadEvent::Refusal);
1604+
}
1605+
} else {
1606+
// No user message found, treat as general refusal
1607+
cx.emit(AcpThreadEvent::Refusal);
1608+
}
15811609
}
15821610

15831611
cx.emit(AcpThreadEvent::Stopped);
@@ -2681,6 +2709,187 @@ mod tests {
26812709
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
26822710
}
26832711

2712+
#[gpui::test]
2713+
async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2714+
use std::sync::atomic::AtomicUsize;
2715+
init_test(cx);
2716+
2717+
let fs = FakeFs::new(cx.executor());
2718+
let project = Project::test(fs, None, cx).await;
2719+
2720+
// Create a connection that simulates refusal after tool result
2721+
let prompt_count = Arc::new(AtomicUsize::new(0));
2722+
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2723+
let prompt_count = prompt_count.clone();
2724+
move |_request, thread, mut cx| {
2725+
let count = prompt_count.fetch_add(1, SeqCst);
2726+
async move {
2727+
if count == 0 {
2728+
// First prompt: Generate a tool call with result
2729+
thread.update(&mut cx, |thread, cx| {
2730+
thread
2731+
.handle_session_update(
2732+
acp::SessionUpdate::ToolCall(acp::ToolCall {
2733+
id: acp::ToolCallId("tool1".into()),
2734+
title: "Test Tool".into(),
2735+
kind: acp::ToolKind::Fetch,
2736+
status: acp::ToolCallStatus::Completed,
2737+
content: vec![],
2738+
locations: vec![],
2739+
raw_input: Some(serde_json::json!({"query": "test"})),
2740+
raw_output: Some(
2741+
serde_json::json!({"result": "inappropriate content"}),
2742+
),
2743+
}),
2744+
cx,
2745+
)
2746+
.unwrap();
2747+
})?;
2748+
2749+
// Now return refusal because of the tool result
2750+
Ok(acp::PromptResponse {
2751+
stop_reason: acp::StopReason::Refusal,
2752+
})
2753+
} else {
2754+
Ok(acp::PromptResponse {
2755+
stop_reason: acp::StopReason::EndTurn,
2756+
})
2757+
}
2758+
}
2759+
.boxed_local()
2760+
}
2761+
}));
2762+
2763+
let thread = cx
2764+
.update(|cx| connection.new_thread(project, Path::new("/test"), cx))
2765+
.await
2766+
.unwrap();
2767+
2768+
// Track if we see a Refusal event
2769+
let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2770+
let saw_refusal_event_captured = saw_refusal_event.clone();
2771+
thread.update(cx, |_thread, cx| {
2772+
cx.subscribe(
2773+
&thread,
2774+
move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2775+
if matches!(event, AcpThreadEvent::Refusal) {
2776+
*saw_refusal_event_captured.lock().unwrap() = true;
2777+
}
2778+
},
2779+
)
2780+
.detach();
2781+
});
2782+
2783+
// Send a user message - this will trigger tool call and then refusal
2784+
let send_task = thread.update(cx, |thread, cx| {
2785+
thread.send(
2786+
vec![acp::ContentBlock::Text(acp::TextContent {
2787+
text: "Hello".into(),
2788+
annotations: None,
2789+
})],
2790+
cx,
2791+
)
2792+
});
2793+
cx.background_executor.spawn(send_task).detach();
2794+
cx.run_until_parked();
2795+
2796+
// Verify that:
2797+
// 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2798+
// 2. The user message was NOT truncated
2799+
assert!(
2800+
*saw_refusal_event.lock().unwrap(),
2801+
"Refusal event should be emitted for tool result refusals"
2802+
);
2803+
2804+
thread.read_with(cx, |thread, _| {
2805+
let entries = thread.entries();
2806+
assert!(entries.len() >= 2, "Should have user message and tool call");
2807+
2808+
// Verify user message is still there
2809+
assert!(
2810+
matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2811+
"User message should not be truncated"
2812+
);
2813+
2814+
// Verify tool call is there with result
2815+
if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2816+
assert!(
2817+
tool_call.raw_output.is_some(),
2818+
"Tool call should have output"
2819+
);
2820+
} else {
2821+
panic!("Expected tool call at index 1");
2822+
}
2823+
});
2824+
}
2825+
2826+
#[gpui::test]
2827+
async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2828+
init_test(cx);
2829+
2830+
let fs = FakeFs::new(cx.executor());
2831+
let project = Project::test(fs, None, cx).await;
2832+
2833+
let refuse_next = Arc::new(AtomicBool::new(false));
2834+
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2835+
let refuse_next = refuse_next.clone();
2836+
move |_request, _thread, _cx| {
2837+
if refuse_next.load(SeqCst) {
2838+
async move {
2839+
Ok(acp::PromptResponse {
2840+
stop_reason: acp::StopReason::Refusal,
2841+
})
2842+
}
2843+
.boxed_local()
2844+
} else {
2845+
async move {
2846+
Ok(acp::PromptResponse {
2847+
stop_reason: acp::StopReason::EndTurn,
2848+
})
2849+
}
2850+
.boxed_local()
2851+
}
2852+
}
2853+
}));
2854+
2855+
let thread = cx
2856+
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2857+
.await
2858+
.unwrap();
2859+
2860+
// Track if we see a Refusal event
2861+
let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2862+
let saw_refusal_event_captured = saw_refusal_event.clone();
2863+
thread.update(cx, |_thread, cx| {
2864+
cx.subscribe(
2865+
&thread,
2866+
move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2867+
if matches!(event, AcpThreadEvent::Refusal) {
2868+
*saw_refusal_event_captured.lock().unwrap() = true;
2869+
}
2870+
},
2871+
)
2872+
.detach();
2873+
});
2874+
2875+
// Send a message that will be refused
2876+
refuse_next.store(true, SeqCst);
2877+
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2878+
.await
2879+
.unwrap();
2880+
2881+
// Verify that a Refusal event WAS emitted for user prompt refusal
2882+
assert!(
2883+
*saw_refusal_event.lock().unwrap(),
2884+
"Refusal event should be emitted for user prompt refusals"
2885+
);
2886+
2887+
// Verify the message was truncated (user prompt refusal)
2888+
thread.read_with(cx, |thread, cx| {
2889+
assert_eq!(thread.to_markdown(cx), "");
2890+
});
2891+
}
2892+
26842893
#[gpui::test]
26852894
async fn test_refusal(cx: &mut TestAppContext) {
26862895
init_test(cx);
@@ -2744,8 +2953,8 @@ mod tests {
27442953
);
27452954
});
27462955

2747-
// Simulate refusing the second message, ensuring the conversation gets
2748-
// truncated to before sending it.
2956+
// Simulate refusing the second message. The message should be truncated
2957+
// when a user prompt is refused.
27492958
refuse_next.store(true, SeqCst);
27502959
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
27512960
.await

crates/agent_ui/src/acp/model_selector_popover.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ impl AcpModelSelectorPopover {
3636
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
3737
self.menu_handle.toggle(window, cx);
3838
}
39+
40+
pub fn active_model_name(&self, cx: &App) -> Option<SharedString> {
41+
self.selector
42+
.read(cx)
43+
.delegate
44+
.active_model()
45+
.map(|model| model.name.clone())
46+
}
3947
}
4048

4149
impl Render for AcpModelSelectorPopover {

0 commit comments

Comments
 (0)