Skip to content

Commit c975846

Browse files
Anthony-Eid, mgsloan
authored and committed
ai: Auto select user model when there's no default (#36722)
This PR identifies automatic configuration options that users can select from the agent panel. If no default provider is set in the user's settings, it falls back to the first recommended option. Additionally, when a user changes the default provider through the settings file, it updates the selected provider for any thread that hasn't had any queries yet.

Release Notes:

- agent: Automatically select a language model provider if there's no user-set provider.

---------

Co-authored-by: Michael Sloan <[email protected]>
1 parent d53dedc commit c975846

File tree

9 files changed

+184
-122
lines changed

9 files changed

+184
-122
lines changed

crates/agent/src/thread.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ impl Thread {
664664
}
665665

666666
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
667-
if self.configured_model.is_none() {
667+
if self.configured_model.is_none() || self.messages.is_empty() {
668668
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
669669
}
670670
self.configured_model.clone()
@@ -2097,7 +2097,7 @@ impl Thread {
20972097
}
20982098

20992099
pub fn summarize(&mut self, cx: &mut Context<Self>) {
2100-
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2100+
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
21012101
println!("No thread summary model");
21022102
return;
21032103
};
@@ -2416,7 +2416,7 @@ impl Thread {
24162416
}
24172417

24182418
let Some(ConfiguredModel { model, provider }) =
2419-
LanguageModelRegistry::read_global(cx).thread_summary_model()
2419+
LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
24202420
else {
24212421
return;
24222422
};
@@ -5410,13 +5410,10 @@ fn main() {{
54105410
}),
54115411
cx,
54125412
);
5413-
registry.set_thread_summary_model(
5414-
Some(ConfiguredModel {
5415-
provider,
5416-
model: model.clone(),
5417-
}),
5418-
cx,
5419-
);
5413+
registry.set_thread_summary_model(Some(ConfiguredModel {
5414+
provider,
5415+
model: model.clone(),
5416+
}));
54205417
})
54215418
});
54225419

crates/agent2/src/agent.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ impl NativeAgent {
228228
) -> Entity<AcpThread> {
229229
let connection = Rc::new(NativeAgentConnection(cx.entity()));
230230
let registry = LanguageModelRegistry::read_global(cx);
231-
let summarization_model = registry.thread_summary_model().map(|c| c.model);
231+
let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
232232

233233
thread_handle.update(cx, |thread, cx| {
234234
thread.set_summarization_model(summarization_model, cx);
@@ -521,7 +521,7 @@ impl NativeAgent {
521521

522522
let registry = LanguageModelRegistry::read_global(cx);
523523
let default_model = registry.default_model().map(|m| m.model);
524-
let summarization_model = registry.thread_summary_model().map(|m| m.model);
524+
let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
525525

526526
for session in self.sessions.values_mut() {
527527
session.thread.update(cx, |thread, cx| {

crates/agent2/src/tests/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,11 +1751,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
17511751
let clock = Arc::new(clock::FakeSystemClock::new());
17521752
let client = Client::new(clock, http_client, cx);
17531753
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1754+
Project::init_settings(cx);
1755+
agent_settings::init(cx);
17541756
language_model::init(client.clone(), cx);
17551757
language_models::init(user_store, client.clone(), cx);
1756-
Project::init_settings(cx);
17571758
LanguageModelRegistry::test(cx);
1758-
agent_settings::init(cx);
17591759
});
17601760
cx.executor().forbid_parking();
17611761

crates/agent_ui/src/language_model_selector.rs

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ use feature_flags::ZedProFeatureFlag;
66
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
77
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
88
use language_model::{
9-
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
10-
LanguageModelRegistry,
9+
ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
1110
};
1211
use ordered_float::OrderedFloat;
1312
use picker::{Picker, PickerDelegate};
@@ -77,7 +76,6 @@ pub struct LanguageModelPickerDelegate {
7776
all_models: Arc<GroupedModels>,
7877
filtered_entries: Vec<LanguageModelPickerEntry>,
7978
selected_index: usize,
80-
_authenticate_all_providers_task: Task<()>,
8179
_subscriptions: Vec<Subscription>,
8280
}
8381

@@ -98,7 +96,6 @@ impl LanguageModelPickerDelegate {
9896
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
9997
filtered_entries: entries,
10098
get_active_model: Arc::new(get_active_model),
101-
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
10299
_subscriptions: vec![cx.subscribe_in(
103100
&LanguageModelRegistry::global(cx),
104101
window,
@@ -142,56 +139,6 @@ impl LanguageModelPickerDelegate {
142139
.unwrap_or(0)
143140
}
144141

145-
/// Authenticates all providers in the [`LanguageModelRegistry`].
146-
///
147-
/// We do this so that we can populate the language selector with all of the
148-
/// models from the configured providers.
149-
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
150-
let authenticate_all_providers = LanguageModelRegistry::global(cx)
151-
.read(cx)
152-
.providers()
153-
.iter()
154-
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
155-
.collect::<Vec<_>>();
156-
157-
cx.spawn(async move |_cx| {
158-
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
159-
if let Err(err) = authenticate_task.await {
160-
if matches!(err, AuthenticateError::CredentialsNotFound) {
161-
// Since we're authenticating these providers in the
162-
// background for the purposes of populating the
163-
// language selector, we don't care about providers
164-
// where the credentials are not found.
165-
} else {
166-
// Some providers have noisy failure states that we
167-
// don't want to spam the logs with every time the
168-
// language model selector is initialized.
169-
//
170-
// Ideally these should have more clear failure modes
171-
// that we know are safe to ignore here, like what we do
172-
// with `CredentialsNotFound` above.
173-
match provider_id.0.as_ref() {
174-
"lmstudio" | "ollama" => {
175-
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
176-
//
177-
// These fail noisily, so we don't log them.
178-
}
179-
"copilot_chat" => {
180-
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
181-
}
182-
_ => {
183-
log::error!(
184-
"Failed to authenticate provider: {}: {err}",
185-
provider_name.0
186-
);
187-
}
188-
}
189-
}
190-
}
191-
}
192-
})
193-
}
194-
195142
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
196143
(self.get_active_model)(cx)
197144
}

crates/git_ui/src/git_panel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
44664466
is_enabled
44674467
.then(|| {
44684468
let ConfiguredModel { provider, model } =
4469-
LanguageModelRegistry::read_global(cx).commit_message_model()?;
4469+
LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
44704470

44714471
provider.is_authenticated(cx).then(|| model)
44724472
})

crates/language_model/src/registry.rs

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use collections::BTreeMap;
66
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
77
use std::{str::FromStr, sync::Arc};
88
use thiserror::Error;
9-
use util::maybe;
109

1110
pub fn init(cx: &mut App) {
1211
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -48,7 +47,9 @@ impl std::fmt::Debug for ConfigurationError {
4847
#[derive(Default)]
4948
pub struct LanguageModelRegistry {
5049
default_model: Option<ConfiguredModel>,
51-
default_fast_model: Option<ConfiguredModel>,
50+
/// This model is automatically configured by a user's environment after
51+
/// authenticating all providers. It's only used when default_model is not available.
52+
environment_fallback_model: Option<ConfiguredModel>,
5253
inline_assistant_model: Option<ConfiguredModel>,
5354
commit_message_model: Option<ConfiguredModel>,
5455
thread_summary_model: Option<ConfiguredModel>,
@@ -104,9 +105,6 @@ impl ConfiguredModel {
104105

105106
pub enum Event {
106107
DefaultModelChanged,
107-
InlineAssistantModelChanged,
108-
CommitMessageModelChanged,
109-
ThreadSummaryModelChanged,
110108
ProviderStateChanged(LanguageModelProviderId),
111109
AddedProvider(LanguageModelProviderId),
112110
RemovedProvider(LanguageModelProviderId),
@@ -238,7 +236,7 @@ impl LanguageModelRegistry {
238236
cx: &mut Context<Self>,
239237
) {
240238
let configured_model = model.and_then(|model| self.select_model(model, cx));
241-
self.set_inline_assistant_model(configured_model, cx);
239+
self.set_inline_assistant_model(configured_model);
242240
}
243241

244242
pub fn select_commit_message_model(
@@ -247,7 +245,7 @@ impl LanguageModelRegistry {
247245
cx: &mut Context<Self>,
248246
) {
249247
let configured_model = model.and_then(|model| self.select_model(model, cx));
250-
self.set_commit_message_model(configured_model, cx);
248+
self.set_commit_message_model(configured_model);
251249
}
252250

253251
pub fn select_thread_summary_model(
@@ -256,7 +254,7 @@ impl LanguageModelRegistry {
256254
cx: &mut Context<Self>,
257255
) {
258256
let configured_model = model.and_then(|model| self.select_model(model, cx));
259-
self.set_thread_summary_model(configured_model, cx);
257+
self.set_thread_summary_model(configured_model);
260258
}
261259

262260
/// Selects and sets the inline alternatives for language models based on
@@ -290,68 +288,60 @@ impl LanguageModelRegistry {
290288
}
291289

292290
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
293-
match (self.default_model.as_ref(), model.as_ref()) {
291+
match (self.default_model(), model.as_ref()) {
294292
(Some(old), Some(new)) if old.is_same_as(new) => {}
295293
(None, None) => {}
296294
_ => cx.emit(Event::DefaultModelChanged),
297295
}
298-
self.default_fast_model = maybe!({
299-
let provider = &model.as_ref()?.provider;
300-
let fast_model = provider.default_fast_model(cx)?;
301-
Some(ConfiguredModel {
302-
provider: provider.clone(),
303-
model: fast_model,
304-
})
305-
});
306296
self.default_model = model;
307297
}
308298

309-
pub fn set_inline_assistant_model(
299+
pub fn set_environment_fallback_model(
310300
&mut self,
311301
model: Option<ConfiguredModel>,
312302
cx: &mut Context<Self>,
313303
) {
314-
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
315-
(Some(old), Some(new)) if old.is_same_as(new) => {}
316-
(None, None) => {}
317-
_ => cx.emit(Event::InlineAssistantModelChanged),
304+
if self.default_model.is_none() {
305+
match (self.environment_fallback_model.as_ref(), model.as_ref()) {
306+
(Some(old), Some(new)) if old.is_same_as(new) => {}
307+
(None, None) => {}
308+
_ => cx.emit(Event::DefaultModelChanged),
309+
}
318310
}
311+
self.environment_fallback_model = model;
312+
}
313+
314+
pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
319315
self.inline_assistant_model = model;
320316
}
321317

322-
pub fn set_commit_message_model(
323-
&mut self,
324-
model: Option<ConfiguredModel>,
325-
cx: &mut Context<Self>,
326-
) {
327-
match (self.commit_message_model.as_ref(), model.as_ref()) {
328-
(Some(old), Some(new)) if old.is_same_as(new) => {}
329-
(None, None) => {}
330-
_ => cx.emit(Event::CommitMessageModelChanged),
331-
}
318+
pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
332319
self.commit_message_model = model;
333320
}
334321

335-
pub fn set_thread_summary_model(
336-
&mut self,
337-
model: Option<ConfiguredModel>,
338-
cx: &mut Context<Self>,
339-
) {
340-
match (self.thread_summary_model.as_ref(), model.as_ref()) {
341-
(Some(old), Some(new)) if old.is_same_as(new) => {}
342-
(None, None) => {}
343-
_ => cx.emit(Event::ThreadSummaryModelChanged),
344-
}
322+
pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
345323
self.thread_summary_model = model;
346324
}
347325

326+
#[track_caller]
348327
pub fn default_model(&self) -> Option<ConfiguredModel> {
349328
#[cfg(debug_assertions)]
350329
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
351330
return None;
352331
}
353332

354-
self.default_model.clone()
333+
self.default_model
334+
.clone()
335+
.or_else(|| self.environment_fallback_model.clone())
336+
}
337+
338+
pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
339+
let provider = self.default_model()?.provider;
340+
let fast_model = provider.default_fast_model(cx)?;
341+
Some(ConfiguredModel {
342+
provider,
343+
model: fast_model,
344+
})
355345
}
356346

357347
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -365,27 +355,27 @@ impl LanguageModelRegistry {
365355
.or_else(|| self.default_model.clone())
366356
}
367357

368-
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
358+
pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
369359
#[cfg(debug_assertions)]
370360
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
371361
return None;
372362
}
373363

374364
self.commit_message_model
375365
.clone()
376-
.or_else(|| self.default_fast_model.clone())
366+
.or_else(|| self.default_fast_model(cx))
377367
.or_else(|| self.default_model.clone())
378368
}
379369

380-
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
370+
pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
381371
#[cfg(debug_assertions)]
382372
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
383373
return None;
384374
}
385375

386376
self.thread_summary_model
387377
.clone()
388-
.or_else(|| self.default_fast_model.clone())
378+
.or_else(|| self.default_fast_model(cx))
389379
.or_else(|| self.default_model.clone())
390380
}
391381

@@ -422,4 +412,34 @@ mod tests {
422412
let providers = registry.read(cx).providers();
423413
assert!(providers.is_empty());
424414
}
415+
416+
#[gpui::test]
417+
async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
418+
let registry = cx.new(|_| LanguageModelRegistry::default());
419+
420+
let provider = FakeLanguageModelProvider::default();
421+
registry.update(cx, |registry, cx| {
422+
registry.register_provider(provider.clone(), cx);
423+
});
424+
425+
cx.update(|cx| provider.authenticate(cx)).await.unwrap();
426+
427+
registry.update(cx, |registry, cx| {
428+
let provider = registry.provider(&provider.id()).unwrap();
429+
430+
registry.set_environment_fallback_model(
431+
Some(ConfiguredModel {
432+
provider: provider.clone(),
433+
model: provider.default_model(cx).unwrap(),
434+
}),
435+
cx,
436+
);
437+
438+
let default_model = registry.default_model().unwrap();
439+
let fallback_model = registry.environment_fallback_model.clone().unwrap();
440+
441+
assert_eq!(default_model.model.id(), fallback_model.model.id());
442+
assert_eq!(default_model.provider.id(), fallback_model.provider.id());
443+
});
444+
}
425445
}

crates/language_models/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
4444
open_ai = { workspace = true, features = ["schemars"] }
4545
open_router = { workspace = true, features = ["schemars"] }
4646
partial-json-fixer.workspace = true
47+
project.workspace = true
4748
release_channel.workspace = true
4849
schemars.workspace = true
4950
serde.workspace = true

0 commit comments

Comments
 (0)