diff --git a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs index 10834b2478..d2baf5fc15 100644 --- a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs @@ -133,7 +133,7 @@ pub struct FluxLoader { impl DiffusionModelLoader for FluxLoader { fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result> { let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?; - let flux_name = api_dir_list!(api, model_id) + let flux_name = api_dir_list!(api, model_id, true) .filter(|x| regex.is_match(x)) .nth(0) .with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?; diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 8b91af6b78..2fee07dde1 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -1,7 +1,7 @@ #[doc(hidden)] #[macro_export] macro_rules! api_dir_list { - ($api:expr, $model_id:expr) => { + ($api:expr, $model_id:expr, $should_panic:expr) => { if std::path::Path::new($model_id).exists() { let listing = std::fs::read_dir($model_id); if listing.is_err() { @@ -29,7 +29,14 @@ macro_rules! api_dir_list { .map(|x| x.rfilename.clone()) .collect::>() }) - .unwrap_or_else(|e| panic!("Could not get directory listing from API: {:?}", e)) + .unwrap_or_else(|e| { + if $should_panic { + panic!("Could not get directory listing from API: {:?}", e) + } else { + tracing::warn!("Could not get directory listing from API: {:?}", e); + Vec::::new() + } + }) .into_iter() } }; @@ -110,7 +117,7 @@ macro_rules! get_paths { revision.clone(), $this.xlora_order.as_ref(), )?; - let gen_conf = if $crate::api_dir_list!(api, model_id) + let gen_conf = if $crate::api_dir_list!(api, model_id, false) .collect::>() .contains(&"generation_config.json".to_string()) { @@ -123,7 +130,7 @@ macro_rules! get_paths { } else { None }; - let preprocessor_config = if $crate::api_dir_list!(api, model_id) + let preprocessor_config = if $crate::api_dir_list!(api, model_id, false) .collect::>() .contains(&"preprocessor_config.json".to_string()) { @@ -136,7 +143,7 @@ macro_rules! get_paths { } else { None }; - let processor_config = if $crate::api_dir_list!(api, model_id) + let processor_config = if $crate::api_dir_list!(api, model_id, false) .collect::>() .contains(&"processor_config.json".to_string()) { @@ -160,7 +167,7 @@ macro_rules! get_paths { model_id )) }; - let chat_template_json_filename = if $crate::api_dir_list!(api, model_id) + let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false) .collect::>() .contains(&"chat_template.json".to_string()) { @@ -290,6 +297,7 @@ macro_rules! get_paths_gguf { false, // Never loading UQFF )?; + info!("GGUF file(s) {:?}", filenames); let adapter_paths = get_xlora_paths( this_model_id.clone(), $this.xlora_model_id.as_ref(), @@ -299,10 +307,10 @@ macro_rules! get_paths_gguf { $this.xlora_order.as_ref(), )?; - let gen_conf = if $crate::api_dir_list!(api, model_id) - .collect::>() - .contains(&"generation_config.json".to_string()) - { + let dir_list = $crate::api_dir_list!(api, model_id, false) + .collect::>(); + + let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) { info!("Loading `generation_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, @@ -313,9 +321,7 @@ macro_rules! get_paths_gguf { None }; - let preprocessor_config = if $crate::api_dir_list!(api, model_id) - .collect::>() - .contains(&"preprocessor_config.json".to_string()) + let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) { info!("Loading `preprocessor_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( @@ -327,10 +333,7 @@ macro_rules! get_paths_gguf { None }; - let processor_config = if $crate::api_dir_list!(api, model_id) - .collect::>() - .contains(&"processor_config.json".to_string()) - { + let processor_config = if dir_list.contains(&"processor_config.json".to_string()) { info!("Loading `processor_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, @@ -341,17 +344,14 @@ macro_rules! get_paths_gguf { None }; - let tokenizer_filename = if $this.model_id.is_some() { + let tokenizer_filename = if $this.model_id.is_some() && dir_list.contains(&"tokenizer.json".to_string()) { info!("Loading `tokenizer.json` at `{}`", this_model_id); $crate::api_get_file!(api, "tokenizer.json", model_id) } else { PathBuf::from_str("")? }; - let chat_template_json_filename = if $crate::api_dir_list!(api, model_id) - .collect::>() - .contains(&"chat_template.json".to_string()) - { + let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) { info!("Loading `chat_template.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 1a29421c14..8cdb29ff7c 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -1128,7 +1128,9 @@ impl AnyMoePipelineMixin for NormalPipeline { )); let mut filenames = vec![]; - for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { + for rfilename in + api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors")) + { filenames.push(api_get_file!(api, &rfilename, model_id)); } @@ -1184,7 +1186,9 @@ impl AnyMoePipelineMixin for NormalPipeline { )); let mut gate_filenames = vec![]; - for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { + for rfilename in + api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors")) + { gate_filenames.push(api_get_file!(api, &rfilename, model_id)); } assert_eq!( diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 1a52a9a914..f63ed71e5b 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -80,7 +80,7 @@ pub fn get_xlora_paths( let model_id = Path::new(&xlora_id); // Get the path for the xlora classifier - let xlora_classifier = &api_dir_list!(api, model_id) + let xlora_classifier = &api_dir_list!(api, model_id, true) .filter(|x| x.contains("xlora_classifier.safetensors")) .collect::>(); if xlora_classifier.len() > 1 { @@ -94,7 +94,7 @@ pub fn get_xlora_paths( // Get the path for the xlora config by checking all for valid versions. // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable - let xlora_configs = &api_dir_list!(api, model_id) + let xlora_configs = &api_dir_list!(api, model_id, true) .filter(|x| x.contains("xlora_config.json")) .collect::>(); if xlora_configs.len() > 1 { @@ -135,7 +135,7 @@ pub fn get_xlora_paths( }); // If there are adapters in the ordering file, get their names and remote paths - let adapter_files = api_dir_list!(api, model_id) + let adapter_files = api_dir_list!(api, model_id, true) .filter_map(|name| { if let Some(ref adapters) = xlora_order.adapters { for adapter_name in adapters { @@ -208,7 +208,7 @@ pub fn get_xlora_paths( let mut output = HashMap::new(); for adapter in preload_adapters { // Get the names and remote paths of the files associated with this adapter - let adapter_files = api_dir_list!(api, &adapter.adapter_model_id) + let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true) .filter_map(|f| { if f.contains(&adapter.name) { Some((f, adapter.name.clone())) @@ -348,7 +348,7 @@ pub fn get_model_paths( let pickle_match = Regex::new(PICKLE_MATCH)?; let mut filenames = vec![]; - let listing = api_dir_list!(api, model_id).filter(|x| { + let listing = api_dir_list!(api, model_id, true).filter(|x| { safetensor_match.is_match(x) || pickle_match.is_match(x) || quant_safetensor_match.is_match(x) diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index f2845e6b2d..a57944a79e 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -961,7 +961,9 @@ impl AnyMoePipelineMixin for VisionPipeline { )); let mut filenames = vec![]; - for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { + for rfilename in + api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors")) + { filenames.push(api_get_file!(api, &rfilename, model_id)); } @@ -1017,7 +1019,9 @@ impl AnyMoePipelineMixin for VisionPipeline { )); let mut gate_filenames = vec![]; - for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { + for rfilename in + api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors")) + { gate_filenames.push(api_get_file!(api, &rfilename, model_id)); } assert_eq!(