@@ -6,6 +6,7 @@ use collections::BTreeMap;
66use gpui:: { App , Context , Entity , EventEmitter , Global , prelude:: * } ;
77use std:: { str:: FromStr , sync:: Arc } ;
88use thiserror:: Error ;
9+ use util:: maybe;
910
1011pub fn init ( cx : & mut App ) {
1112 let registry = cx. new ( |_cx| LanguageModelRegistry :: default ( ) ) ;
@@ -41,9 +42,7 @@ impl std::fmt::Debug for ConfigurationError {
4142#[ derive( Default ) ]
4243pub struct LanguageModelRegistry {
4344 default_model : Option < ConfiguredModel > ,
44- /// This model is automatically configured by a user's environment after
45- /// authenticating all providers. It's only used when default_model is not available.
46- environment_fallback_model : Option < ConfiguredModel > ,
45+ default_fast_model : Option < ConfiguredModel > ,
4746 inline_assistant_model : Option < ConfiguredModel > ,
4847 commit_message_model : Option < ConfiguredModel > ,
4948 thread_summary_model : Option < ConfiguredModel > ,
@@ -99,6 +98,9 @@ impl ConfiguredModel {
9998
10099pub enum Event {
101100 DefaultModelChanged ,
101+ InlineAssistantModelChanged ,
102+ CommitMessageModelChanged ,
103+ ThreadSummaryModelChanged ,
102104 ProviderStateChanged ( LanguageModelProviderId ) ,
103105 AddedProvider ( LanguageModelProviderId ) ,
104106 RemovedProvider ( LanguageModelProviderId ) ,
@@ -224,7 +226,7 @@ impl LanguageModelRegistry {
224226 cx : & mut Context < Self > ,
225227 ) {
226228 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
227- self . set_inline_assistant_model ( configured_model) ;
229+ self . set_inline_assistant_model ( configured_model, cx ) ;
228230 }
229231
230232 pub fn select_commit_message_model (
@@ -233,7 +235,7 @@ impl LanguageModelRegistry {
233235 cx : & mut Context < Self > ,
234236 ) {
235237 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
236- self . set_commit_message_model ( configured_model) ;
238+ self . set_commit_message_model ( configured_model, cx ) ;
237239 }
238240
239241 pub fn select_thread_summary_model (
@@ -242,7 +244,7 @@ impl LanguageModelRegistry {
242244 cx : & mut Context < Self > ,
243245 ) {
244246 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
245- self . set_thread_summary_model ( configured_model) ;
247+ self . set_thread_summary_model ( configured_model, cx ) ;
246248 }
247249
248250 /// Selects and sets the inline alternatives for language models based on
@@ -276,60 +278,68 @@ impl LanguageModelRegistry {
276278 }
277279
278280 pub fn set_default_model ( & mut self , model : Option < ConfiguredModel > , cx : & mut Context < Self > ) {
279- match ( self . default_model ( ) , model. as_ref ( ) ) {
281+ match ( self . default_model . as_ref ( ) , model. as_ref ( ) ) {
280282 ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
281283 ( None , None ) => { }
282284 _ => cx. emit ( Event :: DefaultModelChanged ) ,
283285 }
286+ self . default_fast_model = maybe ! ( {
287+ let provider = & model. as_ref( ) ?. provider;
288+ let fast_model = provider. default_fast_model( cx) ?;
289+ Some ( ConfiguredModel {
290+ provider: provider. clone( ) ,
291+ model: fast_model,
292+ } )
293+ } ) ;
284294 self . default_model = model;
285295 }
286296
287- pub fn set_environment_fallback_model (
297+ pub fn set_inline_assistant_model (
288298 & mut self ,
289299 model : Option < ConfiguredModel > ,
290300 cx : & mut Context < Self > ,
291301 ) {
292- if self . default_model . is_none ( ) {
293- match ( self . environment_fallback_model . as_ref ( ) , model. as_ref ( ) ) {
294- ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
295- ( None , None ) => { }
296- _ => cx. emit ( Event :: DefaultModelChanged ) ,
297- }
302+ match ( self . inline_assistant_model . as_ref ( ) , model. as_ref ( ) ) {
303+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
304+ ( None , None ) => { }
305+ _ => cx. emit ( Event :: InlineAssistantModelChanged ) ,
298306 }
299- self . environment_fallback_model = model;
300- }
301-
302- pub fn set_inline_assistant_model ( & mut self , model : Option < ConfiguredModel > ) {
303307 self . inline_assistant_model = model;
304308 }
305309
306- pub fn set_commit_message_model ( & mut self , model : Option < ConfiguredModel > ) {
310+ pub fn set_commit_message_model (
311+ & mut self ,
312+ model : Option < ConfiguredModel > ,
313+ cx : & mut Context < Self > ,
314+ ) {
315+ match ( self . commit_message_model . as_ref ( ) , model. as_ref ( ) ) {
316+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
317+ ( None , None ) => { }
318+ _ => cx. emit ( Event :: CommitMessageModelChanged ) ,
319+ }
307320 self . commit_message_model = model;
308321 }
309322
310- pub fn set_thread_summary_model ( & mut self , model : Option < ConfiguredModel > ) {
323+ pub fn set_thread_summary_model (
324+ & mut self ,
325+ model : Option < ConfiguredModel > ,
326+ cx : & mut Context < Self > ,
327+ ) {
328+ match ( self . thread_summary_model . as_ref ( ) , model. as_ref ( ) ) {
329+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
330+ ( None , None ) => { }
331+ _ => cx. emit ( Event :: ThreadSummaryModelChanged ) ,
332+ }
311333 self . thread_summary_model = model;
312334 }
313335
314- #[ track_caller]
315336 pub fn default_model ( & self ) -> Option < ConfiguredModel > {
316337 #[ cfg( debug_assertions) ]
317338 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
318339 return None ;
319340 }
320341
321- self . default_model
322- . clone ( )
323- . or_else ( || self . environment_fallback_model . clone ( ) )
324- }
325-
326- pub fn default_fast_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
327- let provider = self . default_model ( ) ?. provider ;
328- let fast_model = provider. default_fast_model ( cx) ?;
329- Some ( ConfiguredModel {
330- provider,
331- model : fast_model,
332- } )
342+ self . default_model . clone ( )
333343 }
334344
335345 pub fn inline_assistant_model ( & self ) -> Option < ConfiguredModel > {
@@ -343,27 +353,27 @@ impl LanguageModelRegistry {
343353 . or_else ( || self . default_model . clone ( ) )
344354 }
345355
346- pub fn commit_message_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
356+ pub fn commit_message_model ( & self ) -> Option < ConfiguredModel > {
347357 #[ cfg( debug_assertions) ]
348358 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
349359 return None ;
350360 }
351361
352362 self . commit_message_model
353363 . clone ( )
354- . or_else ( || self . default_fast_model ( cx ) )
364+ . or_else ( || self . default_fast_model . clone ( ) )
355365 . or_else ( || self . default_model . clone ( ) )
356366 }
357367
358- pub fn thread_summary_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
368+ pub fn thread_summary_model ( & self ) -> Option < ConfiguredModel > {
359369 #[ cfg( debug_assertions) ]
360370 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
361371 return None ;
362372 }
363373
364374 self . thread_summary_model
365375 . clone ( )
366- . or_else ( || self . default_fast_model ( cx ) )
376+ . or_else ( || self . default_fast_model . clone ( ) )
367377 . or_else ( || self . default_model . clone ( ) )
368378 }
369379
@@ -400,34 +410,4 @@ mod tests {
400410 let providers = registry. read ( cx) . providers ( ) ;
401411 assert ! ( providers. is_empty( ) ) ;
402412 }
403-
404- #[ gpui:: test]
405- async fn test_configure_environment_fallback_model ( cx : & mut gpui:: TestAppContext ) {
406- let registry = cx. new ( |_| LanguageModelRegistry :: default ( ) ) ;
407-
408- let provider = FakeLanguageModelProvider :: default ( ) ;
409- registry. update ( cx, |registry, cx| {
410- registry. register_provider ( provider. clone ( ) , cx) ;
411- } ) ;
412-
413- cx. update ( |cx| provider. authenticate ( cx) ) . await . unwrap ( ) ;
414-
415- registry. update ( cx, |registry, cx| {
416- let provider = registry. provider ( & provider. id ( ) ) . unwrap ( ) ;
417-
418- registry. set_environment_fallback_model (
419- Some ( ConfiguredModel {
420- provider : provider. clone ( ) ,
421- model : provider. default_model ( cx) . unwrap ( ) ,
422- } ) ,
423- cx,
424- ) ;
425-
426- let default_model = registry. default_model ( ) . unwrap ( ) ;
427- let fallback_model = registry. environment_fallback_model . clone ( ) . unwrap ( ) ;
428-
429- assert_eq ! ( default_model. model. id( ) , fallback_model. model. id( ) ) ;
430- assert_eq ! ( default_model. provider. id( ) , fallback_model. provider. id( ) ) ;
431- } ) ;
432- }
433413}
0 commit comments