Skip to content

Commit 24ddeee

Browse files
committed
Make bos/eos token IDs optional (#1493)
1 parent 676bd47 commit 24ddeee

File tree

4 files changed

+37
-36
lines changed

4 files changed

+37
-36
lines changed

mistralrs-core/src/pipeline/chat_template.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -114,29 +114,33 @@ pub fn calculate_eos_tokens(
114114
}
115115

116116
if let Some(gen_conf) = gen_conf {
117-
let ids = match gen_conf.eos_token_id {
118-
Either::Left(id) => vec![id],
119-
Either::Right(ids) => ids,
120-
};
121-
for id in ids {
122-
let s = tokenizer
123-
.decode(&[id], false)
124-
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
125-
if !eos_tok_ids.contains(&s) {
126-
eos_tok_ids.push(s);
117+
if let Some(eos_field) = gen_conf.eos_token_id {
118+
let ids = match eos_field {
119+
Either::Left(id) => vec![id],
120+
Either::Right(ids) => ids,
121+
};
122+
for id in ids {
123+
let s = tokenizer
124+
.decode(&[id], false)
125+
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
126+
if !eos_tok_ids.contains(&s) {
127+
eos_tok_ids.push(s);
128+
}
127129
}
128130
}
129131

130-
let ids = match gen_conf.bos_token_id {
131-
Either::Left(id) => vec![id],
132-
Either::Right(ids) => ids,
133-
};
134-
for id in ids {
135-
let s = tokenizer
136-
.decode(&[id], false)
137-
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
138-
if !bos_tok_ids.contains(&s) {
139-
bos_tok_ids.push(s);
132+
if let Some(bos_field) = gen_conf.bos_token_id {
133+
let ids = match bos_field {
134+
Either::Left(id) => vec![id],
135+
Either::Right(ids) => ids,
136+
};
137+
for id in ids {
138+
let s = tokenizer
139+
.decode(&[id], false)
140+
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
141+
if !bos_tok_ids.contains(&s) {
142+
bos_tok_ids.push(s);
143+
}
140144
}
141145
}
142146
}
@@ -176,10 +180,10 @@ pub fn calculate_eos_tokens(
176180
#[allow(dead_code)]
177181
#[derive(Debug, Deserialize)]
178182
pub struct GenerationConfig {
179-
#[serde(with = "either::serde_untagged")]
180-
bos_token_id: Either<u32, Vec<u32>>,
181-
#[serde(with = "either::serde_untagged")]
182-
eos_token_id: Either<u32, Vec<u32>>,
183+
#[serde(with = "either::serde_untagged_optional")]
184+
bos_token_id: Option<Either<u32, Vec<u32>>>,
185+
#[serde(with = "either::serde_untagged_optional")]
186+
eos_token_id: Option<Either<u32, Vec<u32>>>,
183187
}
184188

185189
fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {

mistralrs-core/src/pipeline/ggml.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,9 @@ impl Loader for GGMLLoader {
343343
};
344344

345345
let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
346-
let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
347-
serde_json::from_str(&fs::read_to_string(f).unwrap())
348-
.expect("bos_token_id/eos_token_id missing in generation_config.json")
349-
});
346+
let gen_conf: Option<GenerationConfig> = paths
347+
.get_gen_conf_filename()
348+
.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
350349
let chat_template_explicit = paths
351350
.get_chat_template_explicit()
352351
.as_ref()

mistralrs-core/src/pipeline/gguf.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,9 @@ impl Loader for GGUFLoader {
490490
(None, None)
491491
};
492492

493-
let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
494-
serde_json::from_str(&fs::read_to_string(f).unwrap())
495-
.expect("bos_token_id/eos_token_id missing in generation_config.json")
496-
});
493+
let gen_conf: Option<GenerationConfig> = paths
494+
.get_gen_conf_filename()
495+
.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
497496
let chat_template_explicit = paths
498497
.get_chat_template_explicit()
499498
.as_ref()

mistralrs-core/src/pipeline/vision.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,9 @@ impl Loader for VisionLoader {
539539
Some(processor.get_special_tokens()),
540540
)?;
541541

542-
let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
543-
serde_json::from_str(&fs::read_to_string(f).unwrap())
544-
.expect("bos_token_id/eos_token_id missing in generation_config.json")
545-
});
542+
let gen_conf: Option<GenerationConfig> = paths
543+
.get_gen_conf_filename()
544+
.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
546545
let chat_template_explicit = paths
547546
.get_chat_template_explicit()
548547
.as_ref()

0 commit comments

Comments
 (0)