Skip to content

Commit 7c5fcd0

Browse files
committed
Fix metal large paged attn kv cache
1 parent f38567a commit 7c5fcd0

File tree

3 files changed

+94
-34
lines changed

3 files changed

+94
-34
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,10 @@ license = "MIT"
2727
rust-version = "1.82"
2828

2929
[workspace.dependencies]
30-
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
31-
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
32-
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
33-
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "e8209f3" }
30+
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
31+
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
32+
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
33+
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "98c0436e" }
3434
# candle-core = { path = "../candle/candle-core" }
3535
# candle-nn = { path = "../candle/candle-nn" }
3636
# candle-flash-attn-v3 = { path = "../candle/candle-flash-attn-v3" }

mistralrs-core/src/paged_attention/cache_engine.rs

Lines changed: 84 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@ use std::{
33
sync::{Arc, Mutex, MutexGuard},
44
};
55

6-
use candle_core::{DType, Device, Result, Tensor};
6+
use candle_core::{
7+
from_storage_no_op, DType, Device, MetalStorage, Result, Shape, Storage, Tensor,
8+
};
79
use mistralrs_paged_attn::{copy_blocks, swap_blocks};
810

911
use super::config::ModelConfigLike;
@@ -70,30 +72,88 @@ impl CacheEngine {
7072
.take(model_config.num_layers())
7173
.map(|x| x.as_ref().unwrap_or(device))
7274
{
73-
let key_blocks = unsafe {
74-
Tensor::empty(
75-
(
76-
cache_config.num_gpu_blocks,
77-
key_block_shape.0,
78-
key_block_shape.1,
79-
key_block_shape.2,
80-
key_block_shape.3,
81-
),
82-
dtype,
83-
device,
84-
)?
75+
let key_blocks = if let Device::Metal(dev) = &device {
76+
#[cfg(feature = "metal")]
77+
{
78+
let elem_count = cache_config.num_gpu_blocks
79+
* key_block_shape.0
80+
* key_block_shape.1
81+
* key_block_shape.2
82+
* key_block_shape.3;
83+
let buffer = dev.new_buffer_private(elem_count, dtype, "k_cache")?;
84+
let storage =
85+
Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
86+
from_storage_no_op(
87+
storage,
88+
Shape::from_dims(&[
89+
cache_config.num_gpu_blocks,
90+
key_block_shape.0,
91+
key_block_shape.1,
92+
key_block_shape.2,
93+
key_block_shape.3,
94+
]),
95+
false,
96+
)
97+
}
98+
99+
#[cfg(not(feature = "metal"))]
100+
{
101+
unreachable!()
102+
}
103+
} else {
104+
unsafe {
105+
Tensor::empty(
106+
(
107+
cache_config.num_gpu_blocks,
108+
key_block_shape.0,
109+
key_block_shape.1,
110+
key_block_shape.2,
111+
key_block_shape.3,
112+
),
113+
dtype,
114+
device,
115+
)?
116+
}
85117
};
86-
let value_blocks = unsafe {
87-
Tensor::empty(
88-
(
89-
cache_config.num_gpu_blocks,
90-
value_block_shape.0,
91-
value_block_shape.1,
92-
value_block_shape.2,
93-
),
94-
dtype,
95-
device,
96-
)?
118+
let value_blocks = if let Device::Metal(dev) = &device {
119+
#[cfg(feature = "metal")]
120+
{
121+
let elem_count = cache_config.num_gpu_blocks
122+
* value_block_shape.0
123+
* value_block_shape.1
124+
* value_block_shape.2;
125+
let buffer = dev.new_buffer_private(elem_count, dtype, "v_cache")?;
126+
let storage =
127+
Storage::Metal(MetalStorage::new(buffer, dev.clone(), elem_count, dtype));
128+
from_storage_no_op(
129+
storage,
130+
Shape::from_dims(&[
131+
cache_config.num_gpu_blocks,
132+
value_block_shape.0,
133+
value_block_shape.1,
134+
value_block_shape.2,
135+
]),
136+
false,
137+
)
138+
}
139+
140+
#[cfg(not(feature = "metal"))]
141+
{
142+
unreachable!()
143+
}
144+
} else {
145+
unsafe {
146+
Tensor::empty(
147+
(
148+
cache_config.num_gpu_blocks,
149+
value_block_shape.0,
150+
value_block_shape.1,
151+
value_block_shape.2,
152+
),
153+
dtype,
154+
device,
155+
)?
156+
}
97157
};
98158
gpu_cache.push((key_blocks, value_blocks));
99159
}

0 commit comments

Comments
 (0)