Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ license = "MIT"
rust-version = "1.82"

[workspace.dependencies]
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
# candle-core = { path = "../candle/candle-core" }
# candle-nn = { path = "../candle/candle-nn" }
# candle-flash-attn-v3 = { path = "../candle/candle-flash-attn-v3" }
Expand Down
108 changes: 84 additions & 24 deletions mistralrs-core/src/paged_attention/cache_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::{
sync::{Arc, Mutex, MutexGuard},
};

use candle_core::{DType, Device, Result, Tensor};
use candle_core::{
from_storage_no_op, DType, Device, MetalStorage, Result, Shape, Storage, Tensor,
};
use mistralrs_paged_attn::{copy_blocks, swap_blocks};

use super::config::ModelConfigLike;
Expand Down Expand Up @@ -70,30 +72,88 @@ impl CacheEngine {
.take(model_config.num_layers())
.map(|x| x.as_ref().unwrap_or(device))
{
let key_blocks = unsafe {
Tensor::empty(
(
cache_config.num_gpu_blocks,
key_block_shape.0,
key_block_shape.1,
key_block_shape.2,
key_block_shape.3,
),
dtype,
device,
)?
let key_blocks = if let Device::Metal(dev) = &device {
#[cfg(feature = "metal")]
{
let elem_count = cache_config.num_gpu_blocks
* key_block_shape.0
* key_block_shape.1
* key_block_shape.2
* key_block_shape.3;
let buffer = dev.new_buffer_private(elem_count, dtype, "k_cache")?;
let storage =
Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
from_storage_no_op(
storage,
Shape::from_dims(&[
cache_config.num_gpu_blocks,
key_block_shape.0,
key_block_shape.1,
key_block_shape.2,
key_block_shape.3,
]),
false,
)
}

#[cfg(not(feature = "metal"))]
{
unreachable!()
}
Comment on lines +99 to +102
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider graceful error handling instead of unreachable!().

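Calling `unreachable!()` here panics if a `Device::Metal` value is ever matched in a build compiled without the `metal` feature. Consider returning a proper error instead, so the caller gets a diagnosable failure rather than a crash. The same `unreachable!()` arm appears in the `value_blocks` branch and should be changed in the same way.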

-#[cfg(not(feature = "metal"))]
-{
-    unreachable!()
-}
+#[cfg(not(feature = "metal"))]
+{
+    return Err(candle_core::Error::Msg(
+        "Metal device detected but metal feature not enabled".to_string()
+    ));
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-#[cfg(not(feature = "metal"))]
-{
-    unreachable!()
-}
+#[cfg(not(feature = "metal"))]
+{
+    return Err(candle_core::Error::Msg(
+        "Metal device detected but metal feature not enabled".to_string()
+    ));
+}
🤖 Prompt for AI Agents
In mistralrs-core/src/paged_attention/cache_engine.rs around lines 99 to 102,
replace the unreachable!() macro with graceful error handling by returning a
proper error indicating that the Metal feature is not enabled. This avoids
panics when a Metal device is detected but the "metal" feature is disabled,
improving robustness and error reporting.

} else {
unsafe {
Tensor::empty(
(
cache_config.num_gpu_blocks,
key_block_shape.0,
key_block_shape.1,
key_block_shape.2,
key_block_shape.3,
),
dtype,
device,
)?
}
};
let value_blocks = unsafe {
Tensor::empty(
(
cache_config.num_gpu_blocks,
value_block_shape.0,
value_block_shape.1,
value_block_shape.2,
),
dtype,
device,
)?
let value_blocks = if let Device::Metal(dev) = &device {
#[cfg(feature = "metal")]
{
let elem_count = cache_config.num_gpu_blocks
* value_block_shape.0
* value_block_shape.1
* value_block_shape.2;
let buffer = dev.new_buffer_private(elem_count, dtype, "v_cache")?;
let storage =
Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
from_storage_no_op(
storage,
Shape::from_dims(&[
cache_config.num_gpu_blocks,
value_block_shape.0,
value_block_shape.1,
value_block_shape.2,
]),
false,
)
}

#[cfg(not(feature = "metal"))]
{
unreachable!()
}
} else {
unsafe {
Tensor::empty(
(
cache_config.num_gpu_blocks,
value_block_shape.0,
value_block_shape.1,
value_block_shape.2,
),
dtype,
device,
)?
}
};
gpu_cache.push((key_blocks, value_blocks));
}
Expand Down
Loading