21 changes: 16 additions & 5 deletions mistralrs-core/src/attention.rs
@@ -202,6 +202,15 @@ fn supports_attn_softmax() -> Result<bool> {
Ok(true)
}

/// Synchronize the device if available memory is low.
/// Not *really* sure why this is necessary but it is.
fn maybe_synchronize(device: &Device) -> Result<()> {
    // If less than 4 GiB is available, synchronize
    if MemoryUsage.get_memory_available(device)? < 4 * 1024 * (1024 * 1024) {
        device.synchronize()?;
    }
    Ok(())
}

/// Computes softmax(QK^T*sqrt(d_k))V
fn naive_sdpa(
q: &Tensor,
@@ -210,10 +219,7 @@ fn naive_sdpa(
mask: Option<&Tensor>,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
// If less than 4 GiB is available, synchronize
if MemoryUsage.get_memory_available(q.device())? < 4 * 1024 * (1024 * 1024) {
q.device().synchronize()?;
}
maybe_synchronize(q.device())?;

// Use the faster softmax kernel if the mask is rank 2 or rank 3
if mask.is_some_and(|mask| mask.rank() == 2 || mask.rank() == 3) && supports_attn_softmax()? {
@@ -316,13 +322,18 @@ impl Sdpa {

let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
return naive_sdpa(q, &k, &v, mask, sdpa_params);

if mask.is_some_and(|x| x.rank() == 2) {
return naive_sdpa(q, &k, &v, mask, sdpa_params);
}

// TODO: bench?
#[allow(unused)]
if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
#[cfg(feature = "cuda")]
{
maybe_synchronize(q.device())?;

// The cuBLASLt batched matmul implementation requires rank-3 (dims3) inputs
let k = k.flatten(0, 1)?;
let q = q.flatten(0, 1)?;
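The attention.rs hunks above factor the low-memory synchronize into `maybe_synchronize` and reuse it ahead of the cuBLASLt path. Below is a minimal, self-contained sketch of that guard pattern; `Device`, `available_bytes`, and `LOW_MEMORY_THRESHOLD` are stand-ins for candle's `Device` and `MemoryUsage` APIs, not the crate's real types.

```rust
// Sketch of the low-memory synchronize guard (stand-in types, not candle's).
struct Device;

impl Device {
    fn synchronize(&self) -> Result<(), String> {
        // In the real crate this blocks until queued GPU work finishes,
        // letting the allocator reclaim freed buffers before a big matmul.
        Ok(())
    }
}

// Stand-in for MemoryUsage.get_memory_available(device).
fn available_bytes(_device: &Device) -> Result<u64, String> {
    Ok(8 * 1024 * 1024 * 1024)
}

const LOW_MEMORY_THRESHOLD: u64 = 4 * 1024 * 1024 * 1024; // 4 GiB

fn maybe_synchronize(device: &Device) -> Result<(), String> {
    // Only pay the synchronization cost when memory is actually tight.
    if available_bytes(device)? < LOW_MEMORY_THRESHOLD {
        device.synchronize()?;
    }
    Ok(())
}

fn main() -> Result<(), String> {
    let device = Device;
    maybe_synchronize(&device)?;
    println!("guard passed");
    Ok(())
}
```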
20 changes: 16 additions & 4 deletions mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -2159,7 +2159,10 @@ impl DeviceMappedModelLoader for Idefics3Loader {
layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
};

post_layernorm + patch_embedding + position_embedding + layer_elems
post_layernorm
+ patch_embedding
+ position_embedding
+ layer_elems * cfg.num_hidden_layers
};

let elems = text_elems + connector_elems + vision_transformer;
@@ -2428,7 +2431,10 @@ impl DeviceMappedModelLoader for MiniCpmOLoader {
layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
};

post_layernorm + patch_embedding + position_embedding + layer_elems
post_layernorm
+ patch_embedding
+ position_embedding
+ layer_elems * cfg.num_hidden_layers
};

let elems = text_elems + vision_transformer;
@@ -2733,7 +2739,10 @@ impl DeviceMappedModelLoader for Phi4MMLoader {
layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
};

post_layernorm + patch_embedding + position_embedding + layer_elems
post_layernorm
+ patch_embedding
+ position_embedding
+ layer_elems * cfg.num_hidden_layers
};

proj + glb_gn + sub_gn + vision_transformer
@@ -3299,7 +3308,10 @@ impl DeviceMappedModelLoader for Gemma3Loader {
layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
};

post_layernorm + patch_embedding + position_embedding + layer_elems
post_layernorm
+ patch_embedding
+ position_embedding
+ layer_elems * cfg.num_hidden_layers
} else {
0
};
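The four loader fixes above share one shape: per-layer element counts were previously added once instead of once per transformer layer. A small sketch of the corrected accounting follows; `VisionConfig` and every number below are illustrative, not the crate's real config or sizes.

```rust
// Illustrative element accounting for a vision tower (hypothetical numbers).
struct VisionConfig {
    num_hidden_layers: usize,
}

fn vision_transformer_elems(
    cfg: &VisionConfig,
    post_layernorm: usize,
    patch_embedding: usize,
    position_embedding: usize,
    layer_elems: usize, // elements of ONE transformer layer
) -> usize {
    // The fix: scale the per-layer term by the number of layers.
    post_layernorm
        + patch_embedding
        + position_embedding
        + layer_elems * cfg.num_hidden_layers
}

fn main() {
    let cfg = VisionConfig { num_hidden_layers: 27 };
    let total = vision_transformer_elems(&cfg, 1_152, 1_769_472, 740_096, 25_000_000);
    // Without the multiplication this total would be undercounted by roughly
    // a factor of num_hidden_layers, skewing device-mapping memory estimates.
    println!("vision tower elems: {total}");
}
```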
119 changes: 50 additions & 69 deletions mistralrs-core/src/vision_models/gemma3/inputs_processor.rs
@@ -108,62 +108,13 @@ impl InputsProcessor for Gemma3ImageProcessor {
))));
};

let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = if is_prompt {
get_prompt_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
} else {
get_completion_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
};

let config = other_config.expect("Need a PreProcessorConfig config.");
let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

let has_images = input_seqs.iter().all(|seq| seq.has_images());

let (new_input, pixel_values) = if has_images {
let pixel_values = if has_images {
let mut pixel_values_accum = Vec::new();
let mut all_ids = Vec::new();
let re = Regex::new(BOI_TOKEN).unwrap();
for seq in input_seqs.iter_mut() {
let PreprocessedImages {
@@ -228,30 +179,60 @@ impl InputsProcessor for Gemma3ImageProcessor {
.expect("Detokenization failed!");

let ids = toks.get_ids().to_vec();
all_ids.push(ids.clone());

seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
}

let mut all_ids_new = Vec::new();
let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
for ids in all_ids {
let pad = max_len - ids.len();
all_ids_new
.push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
}

(
Some(Tensor::stack(&all_ids_new, 0).unwrap()),
Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
)
Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
} else {
(None, None)
None
};

let input = match new_input {
Some(new_input) => new_input,
None => input,
let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = if is_prompt {
get_prompt_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
} else {
get_completion_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
};

let inputs: Box<dyn Any> = Box::new(ModelInputs {
Expand Down Expand Up @@ -401,7 +382,7 @@ impl ImagePreProcessor for Gemma3ImageProcessor {

for image in images.iter_mut() {
// Convert to RGB if requested
if config.do_convert_rgb.is_some_and(|x| x) {
if do_convert_rgb {
*image = DynamicImage::ImageRgb8(image.to_rgb8());
}
}
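The reorder above makes the Gemma3 processor rewrite each sequence's tokens (expanding image placeholders) before `get_prompt_input`/`get_completion_input` run, so the shared builder pads the batch once and the hand-rolled pad-and-stack block is gone. A toy sketch of that ordering, with `Seq` and `build_inputs` standing in for the real sequence type and input builders:

```rust
// Toy stand-ins: mutate tokens first, then batch once.
struct Seq {
    toks: Vec<u32>,
}

impl Seq {
    fn set_toks(&mut self, toks: Vec<u32>) {
        self.toks = toks;
    }
}

/// Stand-in for get_prompt_input / get_completion_input: right-pads every
/// sequence to the batch maximum and returns the batched ids.
fn build_inputs(seqs: &[Seq]) -> Vec<Vec<u32>> {
    let max_len = seqs.iter().map(|s| s.toks.len()).max().unwrap_or(0);
    seqs.iter()
        .map(|s| {
            let mut ids = s.toks.clone();
            ids.resize(max_len, 0);
            ids
        })
        .collect()
}

fn main() {
    let mut seqs = vec![Seq { toks: vec![1, 2] }, Seq { toks: vec![1] }];
    // 1) Expand an image placeholder into several tokens first...
    seqs[1].set_toks(vec![1, 9, 9, 9]);
    // 2) ...then build inputs; no second, manual padding pass is needed.
    let batch = build_inputs(&seqs);
    assert_eq!(batch[0], vec![1, 2, 0, 0]);
    assert_eq!(batch[1], vec![1, 9, 9, 9]);
}
```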
112 changes: 48 additions & 64 deletions mistralrs-core/src/vision_models/idefics3/inputs_processor.rs
@@ -135,62 +135,14 @@ impl InputsProcessor for Idefics3ImageProcessor {
))));
};

let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = if is_prompt {
get_prompt_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
} else {
get_completion_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
};
let config = other_config.expect("Need a PreProcessorConfig config.");
let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

let has_images = input_seqs.iter().all(|seq| seq.has_images());

let (new_input, pixel_values, pixel_attention_mask) = if has_images {
let (pixel_values, pixel_attention_mask) = if has_images {
let mut pixel_values_accum = Vec::new();
let mut pixel_attention_mask_accum = Vec::new();
let mut all_ids = Vec::new();
for seq in input_seqs.iter_mut() {
let PreprocessedImages {
pixel_values,
@@ -248,31 +200,63 @@ impl InputsProcessor for Idefics3ImageProcessor {
.expect("Detokenization failed!");

let ids = toks.get_ids().to_vec();
all_ids.push(ids.clone());

seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
}

let mut all_ids_new = Vec::new();
let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
for ids in all_ids {
let pad = max_len - ids.len();
all_ids_new
.push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
}

(
Some(Tensor::stack(&all_ids_new, 0).unwrap()),
Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
)
} else {
(None, None, None)
(None, None)
};

let input = match new_input {
Some(new_input) => new_input,
None => input,
let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = if is_prompt {
get_prompt_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
} else {
get_completion_input(
input_seqs
.iter()
.map(|seq| seq.get_toks().to_vec())
.collect::<Vec<_>>(),
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, // TODO: evaluate if it is possible to batch this
mapper,
)
.nth(0)
.unwrap()
.unwrap()
};

let inputs: Box<dyn Any> = Box::new(ModelInputs {
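Idefics3 gets the same reordering as Gemma3, and it additionally carries a pixel attention mask through the batch. A toy sketch of the accumulate-then-concatenate pattern; `Vec<Vec<f32>>` and `cat0` stand in for candle `Tensor`s and `Tensor::cat(&xs, 0)`:

```rust
// Stand-in for Tensor::cat(&chunks, 0): concatenate along the batch dim.
fn cat0(chunks: Vec<Vec<f32>>) -> Vec<f32> {
    chunks.into_iter().flatten().collect()
}

fn main() {
    let mut pixel_values_accum = Vec::new();
    let mut pixel_attention_mask_accum = Vec::new();
    for seq_id in 0..2u32 {
        // Per-sequence preprocessing output (illustrative values).
        pixel_values_accum.push(vec![seq_id as f32; 4]);
        pixel_attention_mask_accum.push(vec![1.0; 4]);
    }
    // One concatenation per output instead of padding/stacking token ids.
    let pixel_values = cat0(pixel_values_accum);
    let pixel_attention_mask = cat0(pixel_attention_mask_accum);
    assert_eq!(pixel_values.len(), 8);
    assert_eq!(pixel_attention_mask.len(), 8);
}
```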