@@ -6,7 +6,6 @@ use collections::BTreeMap;
66use gpui:: { App , Context , Entity , EventEmitter , Global , prelude:: * } ;
77use std:: { str:: FromStr , sync:: Arc } ;
88use thiserror:: Error ;
9- use util:: maybe;
109
1110pub fn init ( cx : & mut App ) {
1211 let registry = cx. new ( |_cx| LanguageModelRegistry :: default ( ) ) ;
@@ -48,7 +47,9 @@ impl std::fmt::Debug for ConfigurationError {
4847#[ derive( Default ) ]
4948pub struct LanguageModelRegistry {
5049 default_model : Option < ConfiguredModel > ,
51- default_fast_model : Option < ConfiguredModel > ,
50+ /// This model is automatically configured by a user's environment after
51+ /// authenticating all providers. It's only used when default_model is not available.
52+ environment_fallback_model : Option < ConfiguredModel > ,
5253 inline_assistant_model : Option < ConfiguredModel > ,
5354 commit_message_model : Option < ConfiguredModel > ,
5455 thread_summary_model : Option < ConfiguredModel > ,
@@ -104,9 +105,6 @@ impl ConfiguredModel {
104105
105106pub enum Event {
106107 DefaultModelChanged ,
107- InlineAssistantModelChanged ,
108- CommitMessageModelChanged ,
109- ThreadSummaryModelChanged ,
110108 ProviderStateChanged ( LanguageModelProviderId ) ,
111109 AddedProvider ( LanguageModelProviderId ) ,
112110 RemovedProvider ( LanguageModelProviderId ) ,
@@ -238,7 +236,7 @@ impl LanguageModelRegistry {
238236 cx : & mut Context < Self > ,
239237 ) {
240238 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
241- self . set_inline_assistant_model ( configured_model, cx ) ;
239+ self . set_inline_assistant_model ( configured_model) ;
242240 }
243241
244242 pub fn select_commit_message_model (
@@ -247,7 +245,7 @@ impl LanguageModelRegistry {
247245 cx : & mut Context < Self > ,
248246 ) {
249247 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
250- self . set_commit_message_model ( configured_model, cx ) ;
248+ self . set_commit_message_model ( configured_model) ;
251249 }
252250
253251 pub fn select_thread_summary_model (
@@ -256,7 +254,7 @@ impl LanguageModelRegistry {
256254 cx : & mut Context < Self > ,
257255 ) {
258256 let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
259- self . set_thread_summary_model ( configured_model, cx ) ;
257+ self . set_thread_summary_model ( configured_model) ;
260258 }
261259
262260 /// Selects and sets the inline alternatives for language models based on
@@ -290,68 +288,60 @@ impl LanguageModelRegistry {
290288 }
291289
292290 pub fn set_default_model ( & mut self , model : Option < ConfiguredModel > , cx : & mut Context < Self > ) {
293- match ( self . default_model . as_ref ( ) , model. as_ref ( ) ) {
291+ match ( self . default_model ( ) , model. as_ref ( ) ) {
294292 ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
295293 ( None , None ) => { }
296294 _ => cx. emit ( Event :: DefaultModelChanged ) ,
297295 }
298- self . default_fast_model = maybe ! ( {
299- let provider = & model. as_ref( ) ?. provider;
300- let fast_model = provider. default_fast_model( cx) ?;
301- Some ( ConfiguredModel {
302- provider: provider. clone( ) ,
303- model: fast_model,
304- } )
305- } ) ;
306296 self . default_model = model;
307297 }
308298
309- pub fn set_inline_assistant_model (
299+ pub fn set_environment_fallback_model (
310300 & mut self ,
311301 model : Option < ConfiguredModel > ,
312302 cx : & mut Context < Self > ,
313303 ) {
314- match ( self . inline_assistant_model . as_ref ( ) , model. as_ref ( ) ) {
315- ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
316- ( None , None ) => { }
317- _ => cx. emit ( Event :: InlineAssistantModelChanged ) ,
304+ if self . default_model . is_none ( ) {
305+ match ( self . environment_fallback_model . as_ref ( ) , model. as_ref ( ) ) {
306+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
307+ ( None , None ) => { }
308+ _ => cx. emit ( Event :: DefaultModelChanged ) ,
309+ }
318310 }
311+ self . environment_fallback_model = model;
312+ }
313+
314+ pub fn set_inline_assistant_model ( & mut self , model : Option < ConfiguredModel > ) {
319315 self . inline_assistant_model = model;
320316 }
321317
322- pub fn set_commit_message_model (
323- & mut self ,
324- model : Option < ConfiguredModel > ,
325- cx : & mut Context < Self > ,
326- ) {
327- match ( self . commit_message_model . as_ref ( ) , model. as_ref ( ) ) {
328- ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
329- ( None , None ) => { }
330- _ => cx. emit ( Event :: CommitMessageModelChanged ) ,
331- }
318+ pub fn set_commit_message_model ( & mut self , model : Option < ConfiguredModel > ) {
332319 self . commit_message_model = model;
333320 }
334321
335- pub fn set_thread_summary_model (
336- & mut self ,
337- model : Option < ConfiguredModel > ,
338- cx : & mut Context < Self > ,
339- ) {
340- match ( self . thread_summary_model . as_ref ( ) , model. as_ref ( ) ) {
341- ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
342- ( None , None ) => { }
343- _ => cx. emit ( Event :: ThreadSummaryModelChanged ) ,
344- }
322+ pub fn set_thread_summary_model ( & mut self , model : Option < ConfiguredModel > ) {
345323 self . thread_summary_model = model;
346324 }
347325
326+ #[ track_caller]
348327 pub fn default_model ( & self ) -> Option < ConfiguredModel > {
349328 #[ cfg( debug_assertions) ]
350329 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
351330 return None ;
352331 }
353332
354- self . default_model . clone ( )
333+ self . default_model
334+ . clone ( )
335+ . or_else ( || self . environment_fallback_model . clone ( ) )
336+ }
337+
338+ pub fn default_fast_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
339+ let provider = self . default_model ( ) ?. provider ;
340+ let fast_model = provider. default_fast_model ( cx) ?;
341+ Some ( ConfiguredModel {
342+ provider,
343+ model : fast_model,
344+ } )
355345 }
356346
357347 pub fn inline_assistant_model ( & self ) -> Option < ConfiguredModel > {
@@ -365,27 +355,27 @@ impl LanguageModelRegistry {
365355 . or_else ( || self . default_model . clone ( ) )
366356 }
367357
368- pub fn commit_message_model ( & self ) -> Option < ConfiguredModel > {
358+ pub fn commit_message_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
369359 #[ cfg( debug_assertions) ]
370360 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
371361 return None ;
372362 }
373363
374364 self . commit_message_model
375365 . clone ( )
376- . or_else ( || self . default_fast_model . clone ( ) )
366+ . or_else ( || self . default_fast_model ( cx ) )
377367 . or_else ( || self . default_model . clone ( ) )
378368 }
379369
380- pub fn thread_summary_model ( & self ) -> Option < ConfiguredModel > {
370+ pub fn thread_summary_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
381371 #[ cfg( debug_assertions) ]
382372 if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
383373 return None ;
384374 }
385375
386376 self . thread_summary_model
387377 . clone ( )
388- . or_else ( || self . default_fast_model . clone ( ) )
378+ . or_else ( || self . default_fast_model ( cx ) )
389379 . or_else ( || self . default_model . clone ( ) )
390380 }
391381
@@ -422,4 +412,34 @@ mod tests {
422412 let providers = registry. read ( cx) . providers ( ) ;
423413 assert ! ( providers. is_empty( ) ) ;
424414 }
415+
416+ #[ gpui:: test]
417+ async fn test_configure_environment_fallback_model ( cx : & mut gpui:: TestAppContext ) {
418+ let registry = cx. new ( |_| LanguageModelRegistry :: default ( ) ) ;
419+
420+ let provider = FakeLanguageModelProvider :: default ( ) ;
421+ registry. update ( cx, |registry, cx| {
422+ registry. register_provider ( provider. clone ( ) , cx) ;
423+ } ) ;
424+
425+ cx. update ( |cx| provider. authenticate ( cx) ) . await . unwrap ( ) ;
426+
427+ registry. update ( cx, |registry, cx| {
428+ let provider = registry. provider ( & provider. id ( ) ) . unwrap ( ) ;
429+
430+ registry. set_environment_fallback_model (
431+ Some ( ConfiguredModel {
432+ provider : provider. clone ( ) ,
433+ model : provider. default_model ( cx) . unwrap ( ) ,
434+ } ) ,
435+ cx,
436+ ) ;
437+
438+ let default_model = registry. default_model ( ) . unwrap ( ) ;
439+ let fallback_model = registry. environment_fallback_model . clone ( ) . unwrap ( ) ;
440+
441+ assert_eq ! ( default_model. model. id( ) , fallback_model. model. id( ) ) ;
442+ assert_eq ! ( default_model. provider. id( ) , fallback_model. provider. id( ) ) ;
443+ } ) ;
444+ }
425445}
0 commit comments