Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ pub struct FluxLoader {
impl DiffusionModelLoader for FluxLoader {
fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>> {
let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?;
let flux_name = api_dir_list!(api, model_id)
let flux_name = api_dir_list!(api, model_id, true)
.filter(|x| regex.is_match(x))
.nth(0)
.with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?;
Expand Down
44 changes: 22 additions & 22 deletions mistralrs-core/src/pipeline/macros.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
($api:expr, $model_id:expr) => {
($api:expr, $model_id:expr, $should_panic:expr) => {
if std::path::Path::new($model_id).exists() {
let listing = std::fs::read_dir($model_id);
if listing.is_err() {
Expand Down Expand Up @@ -29,7 +29,14 @@ macro_rules! api_dir_list {
.map(|x| x.rfilename.clone())
.collect::<Vec<String>>()
})
.unwrap_or_else(|e| panic!("Could not get directory listing from API: {:?}", e))
.unwrap_or_else(|e| {
if $should_panic {
panic!("Could not get directory listing from API: {:?}", e)
} else {
tracing::warn!("Could not get directory listing from API: {:?}", e);
Vec::<String>::new()
}
})
.into_iter()
}
};
Expand Down Expand Up @@ -110,7 +117,7 @@ macro_rules! get_paths {
revision.clone(),
$this.xlora_order.as_ref(),
)?;
let gen_conf = if $crate::api_dir_list!(api, model_id)
let gen_conf = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"generation_config.json".to_string())
{
Expand All @@ -123,7 +130,7 @@ macro_rules! get_paths {
} else {
None
};
let preprocessor_config = if $crate::api_dir_list!(api, model_id)
let preprocessor_config = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"preprocessor_config.json".to_string())
{
Expand All @@ -136,7 +143,7 @@ macro_rules! get_paths {
} else {
None
};
let processor_config = if $crate::api_dir_list!(api, model_id)
let processor_config = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"processor_config.json".to_string())
{
Expand All @@ -160,7 +167,7 @@ macro_rules! get_paths {
model_id
))
};
let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"chat_template.json".to_string())
{
Expand Down Expand Up @@ -290,6 +297,7 @@ macro_rules! get_paths_gguf {
false, // Never loading UQFF
)?;

info!("GGUF file(s) {:?}", filenames);
let adapter_paths = get_xlora_paths(
this_model_id.clone(),
$this.xlora_model_id.as_ref(),
Expand All @@ -299,10 +307,10 @@ macro_rules! get_paths_gguf {
$this.xlora_order.as_ref(),
)?;

let gen_conf = if $crate::api_dir_list!(api, model_id)
.collect::<Vec<_>>()
.contains(&"generation_config.json".to_string())
{
let dir_list = $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>();

let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
info!("Loading `generation_config.json` at `{}`", this_model_id);
Some($crate::api_get_file!(
api,
Expand All @@ -313,9 +321,7 @@ macro_rules! get_paths_gguf {
None
};

let preprocessor_config = if $crate::api_dir_list!(api, model_id)
.collect::<Vec<_>>()
.contains(&"preprocessor_config.json".to_string())
let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string())
{
info!("Loading `preprocessor_config.json` at `{}`", this_model_id);
Some($crate::api_get_file!(
Expand All @@ -327,10 +333,7 @@ macro_rules! get_paths_gguf {
None
};

let processor_config = if $crate::api_dir_list!(api, model_id)
.collect::<Vec<_>>()
.contains(&"processor_config.json".to_string())
{
let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
info!("Loading `processor_config.json` at `{}`", this_model_id);
Some($crate::api_get_file!(
api,
Expand All @@ -341,17 +344,14 @@ macro_rules! get_paths_gguf {
None
};

let tokenizer_filename = if $this.model_id.is_some() {
let tokenizer_filename = if $this.model_id.is_some() && dir_list.contains(&"tokenizer.json".to_string()) {
info!("Loading `tokenizer.json` at `{}`", this_model_id);
$crate::api_get_file!(api, "tokenizer.json", model_id)
} else {
PathBuf::from_str("")?
};

let chat_template_json_filename = if $crate::api_dir_list!(api, model_id)
.collect::<Vec<_>>()
.contains(&"chat_template.json".to_string())
{
let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
info!("Loading `chat_template.json` at `{}`", this_model_id);
Some($crate::api_get_file!(
api,
Expand Down
8 changes: 6 additions & 2 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,9 @@ impl AnyMoePipelineMixin for NormalPipeline {
));

let mut filenames = vec![];
for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
for rfilename in
api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
{
filenames.push(api_get_file!(api, &rfilename, model_id));
}

Expand Down Expand Up @@ -1184,7 +1186,9 @@ impl AnyMoePipelineMixin for NormalPipeline {
));

let mut gate_filenames = vec![];
for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
for rfilename in
api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
{
gate_filenames.push(api_get_file!(api, &rfilename, model_id));
}
assert_eq!(
Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub fn get_xlora_paths(
let model_id = Path::new(&xlora_id);

// Get the path for the xlora classifier
let xlora_classifier = &api_dir_list!(api, model_id)
let xlora_classifier = &api_dir_list!(api, model_id, true)
.filter(|x| x.contains("xlora_classifier.safetensors"))
.collect::<Vec<_>>();
if xlora_classifier.len() > 1 {
Expand All @@ -94,7 +94,7 @@ pub fn get_xlora_paths(

// Get the path for the xlora config by checking all for valid versions.
// NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
let xlora_configs = &api_dir_list!(api, model_id)
let xlora_configs = &api_dir_list!(api, model_id, true)
.filter(|x| x.contains("xlora_config.json"))
.collect::<Vec<_>>();
if xlora_configs.len() > 1 {
Expand Down Expand Up @@ -135,7 +135,7 @@ pub fn get_xlora_paths(
});

// If there are adapters in the ordering file, get their names and remote paths
let adapter_files = api_dir_list!(api, model_id)
let adapter_files = api_dir_list!(api, model_id, true)
.filter_map(|name| {
if let Some(ref adapters) = xlora_order.adapters {
for adapter_name in adapters {
Expand Down Expand Up @@ -208,7 +208,7 @@ pub fn get_xlora_paths(
let mut output = HashMap::new();
for adapter in preload_adapters {
// Get the names and remote paths of the files associated with this adapter
let adapter_files = api_dir_list!(api, &adapter.adapter_model_id)
let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
.filter_map(|f| {
if f.contains(&adapter.name) {
Some((f, adapter.name.clone()))
Expand Down Expand Up @@ -348,7 +348,7 @@ pub fn get_model_paths(
let pickle_match = Regex::new(PICKLE_MATCH)?;

let mut filenames = vec![];
let listing = api_dir_list!(api, model_id).filter(|x| {
let listing = api_dir_list!(api, model_id, true).filter(|x| {
safetensor_match.is_match(x)
|| pickle_match.is_match(x)
|| quant_safetensor_match.is_match(x)
Expand Down
8 changes: 6 additions & 2 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,9 @@ impl AnyMoePipelineMixin for VisionPipeline {
));

let mut filenames = vec![];
for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
for rfilename in
api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
{
filenames.push(api_get_file!(api, &rfilename, model_id));
}

Expand Down Expand Up @@ -1017,7 +1019,9 @@ impl AnyMoePipelineMixin for VisionPipeline {
));

let mut gate_filenames = vec![];
for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
for rfilename in
api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
{
gate_filenames.push(api_get_file!(api, &rfilename, model_id));
}
assert_eq!(
Expand Down
Loading