Skip to content

Commit 237b0db

Browse files
authored
Make Sequence::set_toks more safe (#1190)
1 parent b8237b2 commit 237b0db

File tree

8 files changed

+24
-47
lines changed

8 files changed

+24
-47
lines changed

mistralrs-core/src/sequence.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
get_mut_group,
3-
pipeline::LayerCaches,
3+
pipeline::{text_models_inputs_processor::PagedAttentionMeta, LayerCaches},
44
response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
55
sampler::{Logprobs, Sampler},
66
ChatCompletionResponse, Usage,
@@ -480,7 +480,11 @@ impl Sequence {
480480
}
481481

482482
/// This will also set prompt_len
483-
pub(crate) fn set_toks(&mut self, toks: Vec<u32>) {
483+
pub(crate) fn set_toks_and_reallocate(
484+
&mut self,
485+
toks: Vec<u32>,
486+
paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
487+
) {
484488
self.tokens.clone_from(&toks);
485489
self.prompt_len = self.tokens.len();
486490
// Handle possible block engine
@@ -495,6 +499,12 @@ impl Sequence {
495499
}
496500
self.custom_metadata
497501
.append_tokens_to_blocks(toks.iter().map(|x| *x as usize).collect::<Vec<_>>());
502+
503+
if let Some(metadata) = paged_attn_metadata {
504+
// Free and then reallocate as appropriate
505+
metadata.block_engine.free_sequence(*self.id());
506+
metadata.block_engine.allocate(self);
507+
}
498508
}
499509

500510
pub fn completion_bytes(&self) -> &[u8] {

mistralrs-core/src/vision_models/idefics3/inputs_processor.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,8 @@ impl InputsProcessor for Idefics3ImageProcessor {
250250

251251
let ids = toks.get_ids().to_vec();
252252
all_ids.push(ids.clone());
253-
seq.set_toks(ids);
254253

255-
if let Some(ref mut metadata) = paged_attn_metadata {
256-
// Free and then reallocate as appropriate
257-
metadata.block_engine.free_sequence(*seq.id());
258-
metadata.block_engine.allocate(*seq);
259-
}
254+
seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
260255
}
261256

262257
let mut all_ids_new = Vec::new();

mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,17 +262,13 @@ impl InputsProcessor for LLaVAInputProcessor {
262262
input_ids.extend(item);
263263
}
264264
// NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
265-
seq.set_toks(
265+
seq.set_toks_and_reallocate(
266266
input_ids
267267
.iter()
268268
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
269269
.collect::<Vec<_>>(),
270+
paged_attn_metadata.as_mut(),
270271
);
271-
if let Some(ref mut metadata) = paged_attn_metadata {
272-
// Free and then reallocate as appropriate
273-
metadata.block_engine.free_sequence(*seq.id());
274-
metadata.block_engine.allocate(*seq);
275-
}
276272

277273
toks.push(input_ids);
278274
}

mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,17 +306,13 @@ impl InputsProcessor for LLaVANextInputProcessor {
306306
input_ids.extend(item);
307307
}
308308
// NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
309-
seq.set_toks(
309+
seq.set_toks_and_reallocate(
310310
input_ids
311311
.iter()
312312
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
313313
.collect::<Vec<_>>(),
314+
paged_attn_metadata.as_mut(),
314315
);
315-
if let Some(ref mut metadata) = paged_attn_metadata {
316-
// Free and then reallocate as appropriate
317-
metadata.block_engine.free_sequence(*seq.id());
318-
metadata.block_engine.allocate(*seq);
319-
}
320316

321317
toks.push(input_ids);
322318
}

mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,7 @@ impl InputsProcessor for MiniCpmOImageProcessor {
315315
.get_ids()
316316
.to_vec();
317317

318-
seq.set_toks(input_ids.clone());
319-
if let Some(ref mut metadata) = paged_attn_metadata {
320-
// Free and then reallocate as appropriate
321-
metadata.block_engine.free_sequence(*seq.id());
322-
metadata.block_engine.allocate(*seq);
323-
}
318+
seq.set_toks_and_reallocate(input_ids.clone(), paged_attn_metadata.as_mut());
324319

325320
let image_start_idx = input_ids
326321
.iter()

mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,13 @@ impl InputsProcessor for Phi3InputsProcessor {
299299
}
300300

301301
// NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
302-
seq.set_toks(
302+
seq.set_toks_and_reallocate(
303303
input_ids
304304
.iter()
305305
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
306306
.collect::<Vec<_>>(),
307+
paged_attn_metadata.as_mut(),
307308
);
308-
if let Some(ref mut metadata) = paged_attn_metadata {
309-
// Free and then reallocate as appropriate
310-
metadata.block_engine.free_sequence(*seq.id());
311-
metadata.block_engine.allocate(*seq);
312-
}
313309

314310
toks.push(input_ids);
315311
}

mistralrs-core/src/vision_models/phi4/inputs_processor.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,13 @@ impl InputsProcessor for Phi4MMInputsProcessor {
241241
.replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
242242
.to_string();
243243

244-
seq.set_toks(
244+
seq.set_toks_and_reallocate(
245245
tokenizer
246246
.encode(detokenized.clone(), true)
247247
.expect("Encode failed")
248248
.get_ids()
249249
.to_vec(),
250+
paged_attn_metadata.as_mut(),
250251
);
251252

252253
seq.set_initial_prompt(detokenized);
@@ -265,16 +266,10 @@ impl InputsProcessor for Phi4MMInputsProcessor {
265266
let mut new_ids = seq.get_toks()[..i].to_vec();
266267
new_ids.extend(vec![token_id; *token_count]);
267268
new_ids.extend(seq.get_toks()[i + 1..].to_vec());
268-
seq.set_toks(new_ids);
269+
seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
269270
i += token_count;
270271
}
271272
toks.push(seq.get_toks().to_vec());
272-
273-
if let Some(ref mut metadata) = paged_attn_metadata {
274-
// Free and then reallocate as appropriate
275-
metadata.block_engine.free_sequence(*seq.id());
276-
metadata.block_engine.allocate(*seq);
277-
}
278273
}
279274

280275
let iter = if is_prompt {

mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,7 @@ impl InputsProcessor for Qwen2VLImageProcessor {
393393
let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
394394
all_continuous_vid_pad.push(continuous_vid_pad);
395395

396-
seq.set_toks(ids);
397-
398-
if let Some(ref mut metadata) = paged_attn_metadata {
399-
// Free and then reallocate as appropriate
400-
metadata.block_engine.free_sequence(*seq.id());
401-
metadata.block_engine.allocate(*seq);
402-
}
396+
seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
403397
}
404398

405399
let mut input_ids_searching = Vec::new();

0 commit comments

Comments
 (0)