Skip to content

Commit b249593

Browse files
benbrandt and as-cii authored
agent2: Always finalize diffs from the edit tool (#36918)
Previously, we wouldn't finalize the diff if an error occurred during editing or the tool call was canceled.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <[email protected]>
1 parent c14d84c commit b249593

File tree

3 files changed

+152
-6
lines changed

3 files changed

+152
-6
lines changed

crates/agent2/src/thread.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
24592459
}
24602460
}
24612461

2462+
/// Awaits the next event on the wrapped stream and returns the `fields` of the update
/// it carries.
///
/// # Panics
///
/// Panics if the next item is anything other than
/// `Some(Ok(ThreadEvent::ToolCallUpdate(ToolCallUpdate::UpdateFields(..))))` — that
/// includes the stream having ended (`None`), an `Err` item, or a different event kind.
/// The panic message includes the debug representation of what was received.
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
2463+
let event = self.0.next().await;
2464+
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2465+
update,
2466+
)))) = event
2467+
{
2468+
update.fields
2469+
} else {
2470+
panic!("Expected update fields but got: {:?}", event);
2471+
}
2472+
}
2473+
2474+
/// Awaits the next event on the wrapped stream and returns the diff entity it carries.
///
/// # Panics
///
/// Panics if the next item is anything other than
/// `Some(Ok(ThreadEvent::ToolCallUpdate(ToolCallUpdate::UpdateDiff(..))))` — that
/// includes the stream having ended (`None`), an `Err` item, or a different event kind.
/// The panic message includes the debug representation of what was received.
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
2475+
let event = self.0.next().await;
2476+
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
2477+
update,
2478+
)))) = event
2479+
{
2480+
update.diff
2481+
} else {
2482+
panic!("Expected diff but got: {:?}", event);
2483+
}
2484+
}
2485+
24622486
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
24632487
let event = self.0.next().await;
24642488
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(

crates/agent2/src/tools/edit_file_tool.rs

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
273273

274274
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
275275
event_stream.update_diff(diff.clone());
276+
let _finalize_diff = util::defer({
277+
let diff = diff.downgrade();
278+
let mut cx = cx.clone();
279+
move || {
280+
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
281+
}
282+
});
276283

277284
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
278285
let old_text = cx
@@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
389396
})
390397
.await;
391398

392-
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
393-
394399
let input_path = input.path.display();
395400
if unified_diff.is_empty() {
396401
anyhow::ensure!(
@@ -1545,6 +1550,100 @@ mod tests {
15451550
);
15461551
}
15471552

1553+
// Checks that the diff emitted by `EditFileTool::run` goes from `Diff::Pending` to
// `Diff::Finalized` in three situations, each run against the same project, thread and
// fake model:
//   1. the edit completes successfully,
//   2. the edit fails because the fake model rejects the request,
//   3. the task returned by `run` is dropped before it finishes.
// In every scenario the test first consumes one "update fields" event and then the
// "diff" event from the tool-call event stream, in that order.
#[gpui::test]
1554+
async fn test_diff_finalization(cx: &mut TestAppContext) {
1555+
init_test(cx);
1556+
let fs = project::FakeFs::new(cx.executor());
1557+
fs.insert_tree("/", json!({"main.rs": ""})).await;
1558+
1559+
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1560+
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1561+
let context_server_registry =
1562+
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1563+
let model = Arc::new(FakeLanguageModel::default());
1564+
let thread = cx.new(|cx| {
1565+
Thread::new(
1566+
project.clone(),
1567+
cx.new(|_cx| ProjectContext::default()),
1568+
context_server_registry.clone(),
1569+
Templates::new(),
1570+
Some(model.clone()),
1571+
cx,
1572+
)
1573+
});
1574+
1575+
// Ensure the diff is finalized after the edit completes.
1576+
{
1577+
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1578+
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1579+
let edit = cx.update(|cx| {
1580+
tool.run(
1581+
EditFileToolInput {
1582+
display_description: "Edit file".into(),
1583+
path: path!("/main.rs").into(),
1584+
mode: EditFileMode::Edit,
1585+
},
1586+
stream_tx,
1587+
cx,
1588+
)
1589+
});
1590+
stream_rx.expect_update_fields().await;
1591+
let diff = stream_rx.expect_diff().await;
1592+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1593+
cx.run_until_parked();
1594+
// End the fake model's completion stream so the pending edit can finish.
model.end_last_completion_stream();
1595+
edit.await.unwrap();
1596+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1597+
}
1598+
1599+
// Ensure the diff is finalized if an error occurs while editing.
1600+
{
1601+
// Make the fake model reject requests so the edit is expected to fail.
model.forbid_requests();
1602+
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1603+
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1604+
let edit = cx.update(|cx| {
1605+
tool.run(
1606+
EditFileToolInput {
1607+
display_description: "Edit file".into(),
1608+
path: path!("/main.rs").into(),
1609+
mode: EditFileMode::Edit,
1610+
},
1611+
stream_tx,
1612+
cx,
1613+
)
1614+
});
1615+
stream_rx.expect_update_fields().await;
1616+
let diff = stream_rx.expect_diff().await;
1617+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1618+
edit.await.unwrap_err();
1619+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1620+
// Restore the model so the next scenario can issue requests again.
model.allow_requests();
1621+
}
1622+
1623+
// Ensure the diff is finalized if the tool call gets dropped.
1624+
{
1625+
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1626+
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1627+
let edit = cx.update(|cx| {
1628+
tool.run(
1629+
EditFileToolInput {
1630+
display_description: "Edit file".into(),
1631+
path: path!("/main.rs").into(),
1632+
mode: EditFileMode::Edit,
1633+
},
1634+
stream_tx,
1635+
cx,
1636+
)
1637+
});
1638+
stream_rx.expect_update_fields().await;
1639+
let diff = stream_rx.expect_diff().await;
1640+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1641+
// Drop the task without awaiting it, then let queued work run before checking.
drop(edit);
1642+
cx.run_until_parked();
1643+
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1644+
}
1645+
}
1646+
15481647
fn init_test(cx: &mut TestAppContext) {
15491648
cx.update(|cx| {
15501649
let settings_store = SettingsStore::test(cx);

crates/language_model/src/fake_provider.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@ use crate::{
44
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
55
LanguageModelRequest, LanguageModelToolChoice,
66
};
7+
use anyhow::anyhow;
78
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
89
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
910
use http_client::Result;
1011
use parking_lot::Mutex;
1112
use smol::stream::StreamExt;
12-
use std::sync::Arc;
13+
use std::sync::{
14+
Arc,
15+
atomic::{AtomicBool, Ordering::SeqCst},
16+
};
1317

1418
#[derive(Clone)]
1519
pub struct FakeLanguageModelProvider {
@@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
106110
>,
107111
)>,
108112
>,
113+
forbid_requests: AtomicBool,
109114
}
110115

111116
impl Default for FakeLanguageModel {
@@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
114119
provider_id: LanguageModelProviderId::from("fake".to_string()),
115120
provider_name: LanguageModelProviderName::from("Fake".to_string()),
116121
current_completion_txs: Mutex::new(Vec::new()),
122+
forbid_requests: AtomicBool::new(false),
117123
}
118124
}
119125
}
120126

121127
impl FakeLanguageModel {
128+
/// Clears the `forbid_requests` flag by storing `false` with `SeqCst` ordering.
///
/// This is the counterpart of `forbid_requests`; `false` is also the flag's default value.
pub fn allow_requests(&self) {
129+
self.forbid_requests.store(false, SeqCst);
130+
}
131+
132+
/// Sets the `forbid_requests` flag by storing `true` with `SeqCst` ordering.
///
/// While the flag is set, the completion path shown in this change returns a
/// `LanguageModelCompletionError::Other` ("requests are forbidden") instead of
/// registering a new pending completion. Call `allow_requests` to clear the flag.
pub fn forbid_requests(&self) {
133+
self.forbid_requests.store(true, SeqCst);
134+
}
135+
122136
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
123137
self.current_completion_txs
124138
.lock()
@@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
251265
LanguageModelCompletionError,
252266
>,
253267
> {
254-
let (tx, rx) = mpsc::unbounded();
255-
self.current_completion_txs.lock().push((request, tx));
256-
async move { Ok(rx.boxed()) }.boxed()
268+
if self.forbid_requests.load(SeqCst) {
269+
async move {
270+
Err(LanguageModelCompletionError::Other(anyhow!(
271+
"requests are forbidden"
272+
)))
273+
}
274+
.boxed()
275+
} else {
276+
let (tx, rx) = mpsc::unbounded();
277+
self.current_completion_txs.lock().push((request, tx));
278+
async move { Ok(rx.boxed()) }.boxed()
279+
}
257280
}
258281

259282
fn as_fake(&self) -> &Self {

0 commit comments

Comments
 (0)