Skip to content

Commit d38a7e1

Browse files
authored
Split Marlin and Paged Attention kernels for faster build (#1525)
* Split Marlin and Paged Attention kernels for faster build
* Typo fix
1 parent d025ebd commit d38a7e1

22 files changed

+1019
-835
lines changed

mistralrs-paged-attn/build.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@ fn main() -> Result<()> {
1111
use std::process::Command;
1212

1313
const OTHER_CONTENT: &str = r#"
14-
pub const COPY_BLOCKS_KERNEL: &str =
15-
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
16-
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
17-
pub const RESHAPE_AND_CACHE_KERNEL: &str =
18-
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
1914
pub const USE_FP8: bool = false;
2015
2116
mod backend;
@@ -25,7 +20,7 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks};
2520
"#;
2621

2722
println!("cargo:rerun-if-changed=build.rs");
28-
println!("cargo:rerun-if-changed=src/cuda/pagedattention.cu");
23+
println!("cargo:rerun-if-changed=src/cuda/pagedattention.cuh");
2924
println!("cargo:rerun-if-changed=src/cuda/copy_blocks_kernel.cu");
3025
println!("cargo:rerun-if-changed=src/cuda/reshape_and_cache_kernel.cu");
3126

@@ -102,9 +97,6 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks};
10297
};
10398
builder.build_lib(out_file);
10499

105-
let bindings = builder.build_ptx().unwrap();
106-
bindings.write("src/cuda/mod.rs").unwrap();
107-
108100
let kernel_dir = PathBuf::from("../mistralrs-paged-attn");
109101
let absolute_kernel_dir = std::fs::canonicalize(kernel_dir).unwrap();
110102

mistralrs-paged-attn/src/cuda/backend/cache.rs

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
use std::{collections::HashMap, iter::zip, ptr::NonNull};
1+
use std::{collections::HashMap, iter::zip};
22

3-
use crate::cuda::backend::get_or_load_func;
4-
5-
use candle_core::cuda::cudarc::driver::LaunchAsync;
3+
use crate::cuda::ffi::{copy_blocks_bf16, copy_blocks_f16, copy_blocks_f32};
64
use candle_core::cuda::WrapErr;
75
use candle_core::cuda_backend::CudaStorageSlice;
86
use candle_core::Result;
97
use candle_core::{
10-
cuda_backend::cudarc::driver::{CudaSlice, DevicePtr, LaunchConfig},
11-
Device, IndexOp, Storage, Tensor,
8+
cuda_backend::cudarc::driver::{CudaSlice, DevicePtr},
9+
DType, Device, IndexOp, Storage, Tensor,
1210
};
1311

14-
use super::{Conjoined, COPY_BLOCKS_KERNEL_NAME};
15-
use crate::COPY_BLOCKS_KERNEL;
16-
1712
pub fn copy_blocks(
1813
key_caches: Vec<&mut Tensor>,
1914
value_caches: Vec<&mut Tensor>,
@@ -46,6 +41,8 @@ pub fn copy_blocks(
4641
key_cache_ptrs.reserve_exact(num_layers as usize);
4742
let mut value_cache_ptrs = Vec::new();
4843
value_cache_ptrs.reserve_exact(num_layers as usize);
44+
let mut dtype = DType::F32;
45+
4946
for (key_cache, value_cache) in zip(&key_caches, &value_caches) {
5047
key_cache.to_device(cache_dev)?;
5148
value_cache.to_device(cache_dev)?;
@@ -74,11 +71,13 @@ pub fn copy_blocks(
7471
(CudaStorageSlice::BF16(slice_key), CudaStorageSlice::BF16(slice_value)) => {
7572
let ptr_key = *slice_key.slice(0..).device_ptr();
7673
let ptr_value = *slice_value.slice(0..).device_ptr();
74+
dtype = DType::BF16;
7775
(ptr_key, ptr_value)
7876
}
7977
(CudaStorageSlice::F16(slice_key), CudaStorageSlice::F16(slice_value)) => {
8078
let ptr_key = *slice_key.slice(0..).device_ptr();
8179
let ptr_value = *slice_value.slice(0..).device_ptr();
80+
dtype = DType::F16;
8281
(ptr_key, ptr_value)
8382
}
8483
(CudaStorageSlice::F32(slice_key), CudaStorageSlice::F32(slice_value)) => {
@@ -102,19 +101,10 @@ pub fn copy_blocks(
102101
}
103102
}
104103
let num_pairs: u32 = (block_mapping_vec.len() / 2).try_into().unwrap();
105-
let block_mapping_ptr = Conjoined::new(
106-
NonNull::new(block_mapping_vec.as_mut_ptr()).unwrap(),
107-
&mut block_mapping_vec,
108-
);
109104

110-
let key_cache_ptr = Conjoined::new(
111-
NonNull::new(key_cache_ptrs.as_mut_ptr()).unwrap(),
112-
&mut key_cache_ptrs,
113-
);
114-
let value_cache_ptr = Conjoined::new(
115-
NonNull::new(value_cache_ptrs.as_mut_ptr()).unwrap(),
116-
&mut value_cache_ptrs,
117-
);
105+
let key_cache_ptr = key_cache_ptrs.as_mut_ptr() as *mut core::ffi::c_void;
106+
let value_cache_ptr = value_cache_ptrs.as_mut_ptr() as *mut core::ffi::c_void;
107+
let block_mapping_ptr = block_mapping_vec.as_mut_ptr() as *const core::ffi::c_void;
118108

119109
let numel_per_block: u32 = key_caches
120110
.first()
@@ -126,34 +116,42 @@ pub fn copy_blocks(
126116
.product::<usize>()
127117
.try_into()
128118
.unwrap();
129-
let launch_conf = LaunchConfig {
130-
grid_dim: (num_layers, num_pairs, 1u32),
131-
block_dim: (numel_per_block.min(1024), 1u32, 1u32),
132-
shared_mem_bytes: 0,
133-
};
134-
let stream = dev.fork_default_stream().w()?;
135-
136-
let kernel = get_or_load_func(
137-
COPY_BLOCKS_KERNEL,
138-
COPY_BLOCKS_KERNEL_NAME,
139-
key_caches.first().unwrap().dtype(),
140-
None,
141-
dev,
142-
)?;
143119

144-
unsafe {
145-
kernel
146-
.launch_on_stream(
147-
&stream,
148-
launch_conf,
149-
(
150-
key_cache_ptr,
151-
value_cache_ptr,
152-
block_mapping_ptr,
153-
numel_per_block as i32,
154-
),
155-
)
156-
.w()?;
120+
match dtype {
121+
candle_core::DType::BF16 => unsafe {
122+
copy_blocks_bf16(
123+
key_cache_ptr,
124+
value_cache_ptr,
125+
block_mapping_ptr,
126+
num_layers as i32,
127+
num_pairs as i32,
128+
numel_per_block as i32,
129+
*dev.cu_stream() as i64,
130+
);
131+
},
132+
candle_core::DType::F16 => unsafe {
133+
copy_blocks_f16(
134+
key_cache_ptr,
135+
value_cache_ptr,
136+
block_mapping_ptr,
137+
num_layers as i32,
138+
num_pairs as i32,
139+
numel_per_block as i32,
140+
*dev.cu_stream() as i64,
141+
);
142+
},
143+
candle_core::DType::F32 => unsafe {
144+
copy_blocks_f32(
145+
key_cache_ptr,
146+
value_cache_ptr,
147+
block_mapping_ptr,
148+
num_layers as i32,
149+
num_pairs as i32,
150+
numel_per_block as i32,
151+
*dev.cu_stream() as i64,
152+
);
153+
},
154+
_ => {}
157155
}
158156

159157
Ok(())
mistralrs-paged-attn/src/cuda/backend/mod.rs

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,4 @@
11
mod cache;
22
mod paged_attention;
3-
4-
use std::{
5-
marker::PhantomData,
6-
ptr::{addr_of, NonNull},
7-
};
8-
9-
use candle_core::{
10-
cuda::cudarc::driver::DeviceRepr, cuda_backend::cudarc::driver::CudaFunction, CudaDevice,
11-
DType, Result,
12-
};
13-
143
pub use cache::{copy_blocks, swap_blocks};
154
pub use paged_attention::{paged_attention, reshape_and_cache};
16-
17-
const COPY_BLOCKS_KERNEL_NAME: &str = "copy_blocks_kernel";
18-
19-
pub fn get_or_load_func(
20-
ptx_file: &'static str,
21-
kernel_base: &str,
22-
dtype: DType,
23-
suffix: Option<&str>,
24-
device: &CudaDevice,
25-
) -> Result<CudaFunction> {
26-
let spec = match dtype {
27-
DType::U8 => "_u8",
28-
DType::U32 => "_u32",
29-
DType::I16 => "_i16",
30-
DType::I32 => "_i32",
31-
DType::I64 => "_i64",
32-
DType::BF16 => "_bf16",
33-
DType::F16 => "_f16",
34-
DType::F32 => "_f32",
35-
DType::F64 => "_f64",
36-
DType::F8E4M3 => "_f8_e4m3",
37-
};
38-
let spec = if let Some(suffix) = suffix {
39-
spec.to_owned() + suffix
40-
} else {
41-
spec.to_owned()
42-
};
43-
let kernel = kernel_base.to_owned() + &spec;
44-
device.get_or_load_func(&kernel, ptx_file)
45-
}
46-
47-
#[repr(transparent)]
48-
struct Conjoined<'a, T, R> {
49-
raw: *mut T,
50-
_ref: PhantomData<&'a mut R>,
51-
}
52-
53-
impl<'a, T, R> Conjoined<'a, T, R> {
54-
fn new(raw: NonNull<T>, _ref: &'a mut R) -> Self {
55-
Self {
56-
raw: raw.as_ptr(),
57-
_ref: PhantomData,
58-
}
59-
}
60-
}
61-
62-
/// According to the docs: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
63-
/// Each of the kernel params (*mut c_void) "must point to a region of memory from which the actual kernel parameter will be copied".
64-
/// This means that we must return a pointer to our pointer.
65-
///
66-
/// ## Safety
67-
/// - The returned pointer **must not** outlive the &self reference. Otherwise, a dangling pointer is created.
68-
unsafe impl<T, R> DeviceRepr for Conjoined<'_, T, R> {
69-
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
70-
addr_of!(self.raw) as *mut _
71-
}
72-
}

mistralrs-paged-attn/src/cuda/backend/paged_attention.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use crate::cuda::ffi;
2-
use crate::cuda::ffi::{paged_attention_v1, paged_attention_v2};
2+
use crate::cuda::ffi::{
3+
paged_attention_v1_bf16, paged_attention_v1_f16, paged_attention_v1_f32,
4+
paged_attention_v2_bf16, paged_attention_v2_f16, paged_attention_v2_f32,
5+
};
36
use candle::backend::BackendStorage;
47
use candle::cuda_backend::cudarc::driver::DevicePtr;
58
use candle::cuda_backend::WrapErr;
@@ -31,13 +34,6 @@ impl PagedAttention {
3134
q_l: &Layout,
3235
) -> Result<(CudaStorage, Shape)> {
3336
let dtype = q.dtype();
34-
let internal_type = match dtype {
35-
DType::F16 => 0,
36-
DType::BF16 => 1,
37-
DType::F32 => 2,
38-
dtype => candle::bail!("dtype {dtype:?} is not supported"),
39-
};
40-
4137
let cache_dtype = match self.key_cache.dtype() {
4238
DType::F16 => 0,
4339
DType::BF16 => 1,
@@ -233,8 +229,14 @@ impl PagedAttention {
233229
let cl_ptr = *cl.device_ptr() as *const core::ffi::c_int;
234230

235231
if use_v1 {
232+
let paged_attention_v1_func = match dtype {
233+
DType::F16 => paged_attention_v1_f16,
234+
DType::BF16 => paged_attention_v1_bf16,
235+
DType::F32 => paged_attention_v1_f32,
236+
dtype => candle::bail!("dtype {dtype:?} is not supported"),
237+
};
236238
unsafe {
237-
paged_attention_v1(
239+
paged_attention_v1_func(
238240
out_ptr,
239241
q_ptr,
240242
kc_ptr,
@@ -255,7 +257,6 @@ impl PagedAttention {
255257
kv_block_stride as c_int,
256258
kv_head_stride as c_int,
257259
*dev.cu_stream(),
258-
internal_type,
259260
cache_dtype,
260261
k_scale_ptr,
261262
v_scale_ptr,
@@ -272,8 +273,14 @@ impl PagedAttention {
272273
let exp_sums_ptr = *exp_sums.device_ptr() as *const f32;
273274
let max_logits_ptr = *max_logits.device_ptr() as *const f32;
274275

276+
let paged_attention_v2_func = match dtype {
277+
DType::F16 => paged_attention_v2_f16,
278+
DType::BF16 => paged_attention_v2_bf16,
279+
DType::F32 => paged_attention_v2_f32,
280+
dtype => candle::bail!("dtype {dtype:?} is not supported"),
281+
};
275282
unsafe {
276-
paged_attention_v2(
283+
paged_attention_v2_func(
277284
out_ptr,
278285
exp_sums_ptr,
279286
max_logits_ptr,
@@ -297,7 +304,6 @@ impl PagedAttention {
297304
kv_block_stride as c_int,
298305
kv_head_stride as c_int,
299306
*dev.cu_stream(),
300-
internal_type,
301307
cache_dtype,
302308
k_scale_ptr,
303309
v_scale_ptr,

0 commit comments

Comments (0)