diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs index 943d353e52..330512a2ce 100644 --- a/mistralrs-core/src/pipeline/chat_template.rs +++ b/mistralrs-core/src/pipeline/chat_template.rs @@ -114,29 +114,33 @@ pub fn calculate_eos_tokens( } if let Some(gen_conf) = gen_conf { - let ids = match gen_conf.eos_token_id { - Either::Left(id) => vec![id], - Either::Right(ids) => ids, - }; - for id in ids { - let s = tokenizer - .decode(&[id], false) - .unwrap_or_else(|_| panic!("Unable to decode id {id})")); - if !eos_tok_ids.contains(&s) { - eos_tok_ids.push(s); + if let Some(eos_field) = gen_conf.eos_token_id { + let ids = match eos_field { + Either::Left(id) => vec![id], + Either::Right(ids) => ids, + }; + for id in ids { + let s = tokenizer + .decode(&[id], false) + .unwrap_or_else(|_| panic!("Unable to decode id {id})")); + if !eos_tok_ids.contains(&s) { + eos_tok_ids.push(s); + } } } - let ids = match gen_conf.bos_token_id { - Either::Left(id) => vec![id], - Either::Right(ids) => ids, - }; - for id in ids { - let s = tokenizer - .decode(&[id], false) - .unwrap_or_else(|_| panic!("Unable to decode id {id})")); - if !bos_tok_ids.contains(&s) { - bos_tok_ids.push(s); + if let Some(bos_field) = gen_conf.bos_token_id { + let ids = match bos_field { + Either::Left(id) => vec![id], + Either::Right(ids) => ids, + }; + for id in ids { + let s = tokenizer + .decode(&[id], false) + .unwrap_or_else(|_| panic!("Unable to decode id {id})")); + if !bos_tok_ids.contains(&s) { + bos_tok_ids.push(s); + } } } } @@ -176,10 +180,10 @@ pub fn calculate_eos_tokens( #[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct GenerationConfig { - #[serde(with = "either::serde_untagged")] - bos_token_id: Either>, - #[serde(with = "either::serde_untagged")] - eos_token_id: Either>, + #[serde(with = "either::serde_untagged_optional")] + bos_token_id: Option>>, + #[serde(with = "either::serde_untagged_optional")] + eos_token_id: Option>>, } fn tojson(value: Value, kwargs: Kwargs) -> Result { diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index d53fd31fdd..2becd30776 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -343,10 +343,9 @@ impl Loader for GGMLLoader { }; let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?; - let gen_conf: Option = paths.get_gen_conf_filename().map(|f| { - serde_json::from_str(&fs::read_to_string(f).unwrap()) - .expect("bos_token_id/eos_token_id missing in generation_config.json") - }); + let gen_conf: Option = paths + .get_gen_conf_filename() + .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); let chat_template_explicit = paths .get_chat_template_explicit() .as_ref() diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 3957c6e9d1..fe2cc6b87c 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -489,10 +489,9 @@ impl Loader for GGUFLoader { (None, None) }; - let gen_conf: Option = paths.get_gen_conf_filename().map(|f| { - serde_json::from_str(&fs::read_to_string(f).unwrap()) - .expect("bos_token_id/eos_token_id missing in generation_config.json") - }); + let gen_conf: Option = paths + .get_gen_conf_filename() + .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); let chat_template_explicit = paths .get_chat_template_explicit() .as_ref() diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index 2c8733a73c..bf82d2bc85 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -539,10 +539,9 @@ impl Loader for VisionLoader { Some(processor.get_special_tokens()), )?; - let gen_conf: Option = paths.get_gen_conf_filename().map(|f| { - serde_json::from_str(&fs::read_to_string(f).unwrap()) - .expect("bos_token_id/eos_token_id missing in generation_config.json") - }); + let gen_conf: Option = paths + .get_gen_conf_filename() + .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); let chat_template_explicit = paths .get_chat_template_explicit() .as_ref()