10 changes: 1 addition & 9 deletions mistralrs-paged-attn/build.rs
@@ -11,11 +11,6 @@ fn main() -> Result<()> {
use std::process::Command;

const OTHER_CONTENT: &str = r#"
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
pub const USE_FP8: bool = false;

mod backend;
@@ -25,7 +20,7 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks};
"#;

println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/cuda/pagedattention.cu");
println!("cargo:rerun-if-changed=src/cuda/pagedattention.cuh");
println!("cargo:rerun-if-changed=src/cuda/copy_blocks_kernel.cu");
println!("cargo:rerun-if-changed=src/cuda/reshape_and_cache_kernel.cu");

@@ -102,9 +97,6 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks};
};
builder.build_lib(out_file);

let bindings = builder.build_ptx().unwrap();
bindings.write("src/cuda/mod.rs").unwrap();

let kernel_dir = PathBuf::from("../mistralrs-paged-attn");
let absolute_kernel_dir = std::fs::canonicalize(kernel_dir).unwrap();

94 changes: 46 additions & 48 deletions mistralrs-paged-attn/src/cuda/backend/cache.rs
@@ -1,19 +1,14 @@
use std::{collections::HashMap, iter::zip, ptr::NonNull};
use std::{collections::HashMap, iter::zip};

use crate::cuda::backend::get_or_load_func;

use candle_core::cuda::cudarc::driver::LaunchAsync;
use crate::cuda::ffi::{copy_blocks_bf16, copy_blocks_f16, copy_blocks_f32};
use candle_core::cuda::WrapErr;
use candle_core::cuda_backend::CudaStorageSlice;
use candle_core::Result;
use candle_core::{
cuda_backend::cudarc::driver::{CudaSlice, DevicePtr, LaunchConfig},
Device, IndexOp, Storage, Tensor,
cuda_backend::cudarc::driver::{CudaSlice, DevicePtr},
DType, Device, IndexOp, Storage, Tensor,
};

use super::{Conjoined, COPY_BLOCKS_KERNEL_NAME};
use crate::COPY_BLOCKS_KERNEL;

pub fn copy_blocks(
key_caches: Vec<&mut Tensor>,
value_caches: Vec<&mut Tensor>,
@@ -46,6 +41,8 @@ pub fn copy_blocks(
key_cache_ptrs.reserve_exact(num_layers as usize);
let mut value_cache_ptrs = Vec::new();
value_cache_ptrs.reserve_exact(num_layers as usize);
let mut dtype = DType::F32;

for (key_cache, value_cache) in zip(&key_caches, &value_caches) {
key_cache.to_device(cache_dev)?;
value_cache.to_device(cache_dev)?;
@@ -74,11 +71,13 @@ pub fn copy_blocks(
(CudaStorageSlice::BF16(slice_key), CudaStorageSlice::BF16(slice_value)) => {
let ptr_key = *slice_key.slice(0..).device_ptr();
let ptr_value = *slice_value.slice(0..).device_ptr();
dtype = DType::BF16;
(ptr_key, ptr_value)
}
(CudaStorageSlice::F16(slice_key), CudaStorageSlice::F16(slice_value)) => {
let ptr_key = *slice_key.slice(0..).device_ptr();
let ptr_value = *slice_value.slice(0..).device_ptr();
dtype = DType::F16;
(ptr_key, ptr_value)
}
(CudaStorageSlice::F32(slice_key), CudaStorageSlice::F32(slice_value)) => {
@@ -102,19 +101,10 @@ pub fn copy_blocks(
}
}
let num_pairs: u32 = (block_mapping_vec.len() / 2).try_into().unwrap();
let block_mapping_ptr = Conjoined::new(
NonNull::new(block_mapping_vec.as_mut_ptr()).unwrap(),
&mut block_mapping_vec,
);

let key_cache_ptr = Conjoined::new(
NonNull::new(key_cache_ptrs.as_mut_ptr()).unwrap(),
&mut key_cache_ptrs,
);
let value_cache_ptr = Conjoined::new(
NonNull::new(value_cache_ptrs.as_mut_ptr()).unwrap(),
&mut value_cache_ptrs,
);
let key_cache_ptr = key_cache_ptrs.as_mut_ptr() as *mut core::ffi::c_void;
let value_cache_ptr = value_cache_ptrs.as_mut_ptr() as *mut core::ffi::c_void;
let block_mapping_ptr = block_mapping_vec.as_mut_ptr() as *const core::ffi::c_void;

let numel_per_block: u32 = key_caches
.first()
@@ -126,34 +116,42 @@ pub fn copy_blocks(
.product::<usize>()
.try_into()
.unwrap();
let launch_conf = LaunchConfig {
grid_dim: (num_layers, num_pairs, 1u32),
block_dim: (numel_per_block.min(1024), 1u32, 1u32),
shared_mem_bytes: 0,
};
let stream = dev.fork_default_stream().w()?;

let kernel = get_or_load_func(
COPY_BLOCKS_KERNEL,
COPY_BLOCKS_KERNEL_NAME,
key_caches.first().unwrap().dtype(),
None,
dev,
)?;

unsafe {
kernel
.launch_on_stream(
&stream,
launch_conf,
(
key_cache_ptr,
value_cache_ptr,
block_mapping_ptr,
numel_per_block as i32,
),
)
.w()?;
match dtype {
candle_core::DType::BF16 => unsafe {
copy_blocks_bf16(
key_cache_ptr,
value_cache_ptr,
block_mapping_ptr,
num_layers as i32,
num_pairs as i32,
numel_per_block as i32,
*dev.cu_stream() as i64,
);
},
candle_core::DType::F16 => unsafe {
copy_blocks_f16(
key_cache_ptr,
value_cache_ptr,
block_mapping_ptr,
num_layers as i32,
num_pairs as i32,
numel_per_block as i32,
*dev.cu_stream() as i64,
);
},
candle_core::DType::F32 => unsafe {
copy_blocks_f32(
key_cache_ptr,
value_cache_ptr,
block_mapping_ptr,
num_layers as i32,
num_pairs as i32,
numel_per_block as i32,
*dev.cu_stream() as i64,
);
},
_ => {}
}

Ok(())
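The new copy_blocks_* call sites above imply per-dtype extern "C" entry points in crate::cuda::ffi. That module is not shown in this diff, so the following is only a minimal sketch of what the bf16 declaration presumably looks like, inferred from the argument types used at the call site (the f16 and f32 variants would share the same shape):

use core::ffi::c_void;

extern "C" {
    // Hypothetical declaration, reconstructed from the call site in cache.rs:
    // arrays of device pointers for the key/value caches, a flat array of
    // block-index pairs (num_pairs = block_mapping_vec.len() / 2), the launch
    // extents, and the CUDA stream handle cast to i64.
    fn copy_blocks_bf16(
        key_cache_ptrs: *mut c_void,
        value_cache_ptrs: *mut c_void,
        block_mapping: *const c_void,
        num_layers: i32,
        num_pairs: i32,
        numel_per_block: i32,
        stream: i64,
    );
}
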
68 changes: 0 additions & 68 deletions mistralrs-paged-attn/src/cuda/backend/mod.rs
@@ -1,72 +1,4 @@
mod cache;
mod paged_attention;

use std::{
marker::PhantomData,
ptr::{addr_of, NonNull},
};

use candle_core::{
cuda::cudarc::driver::DeviceRepr, cuda_backend::cudarc::driver::CudaFunction, CudaDevice,
DType, Result,
};

pub use cache::{copy_blocks, swap_blocks};
pub use paged_attention::{paged_attention, reshape_and_cache};

const COPY_BLOCKS_KERNEL_NAME: &str = "copy_blocks_kernel";

pub fn get_or_load_func(
ptx_file: &'static str,
kernel_base: &str,
dtype: DType,
suffix: Option<&str>,
device: &CudaDevice,
) -> Result<CudaFunction> {
let spec = match dtype {
DType::U8 => "_u8",
DType::U32 => "_u32",
DType::I16 => "_i16",
DType::I32 => "_i32",
DType::I64 => "_i64",
DType::BF16 => "_bf16",
DType::F16 => "_f16",
DType::F32 => "_f32",
DType::F64 => "_f64",
DType::F8E4M3 => "_f8_e4m3",
};
let spec = if let Some(suffix) = suffix {
spec.to_owned() + suffix
} else {
spec.to_owned()
};
let kernel = kernel_base.to_owned() + &spec;
device.get_or_load_func(&kernel, ptx_file)
}

#[repr(transparent)]
struct Conjoined<'a, T, R> {
raw: *mut T,
_ref: PhantomData<&'a mut R>,
}

impl<'a, T, R> Conjoined<'a, T, R> {
fn new(raw: NonNull<T>, _ref: &'a mut R) -> Self {
Self {
raw: raw.as_ptr(),
_ref: PhantomData,
}
}
}

/// According to the docs: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
/// Each of the kernel params (*mut c_void) "must point to a region of memory from which the actual kernel parameter will be copied".
/// This means that we must return a pointer to our pointer.
///
/// ## Safety
/// - The returned pointer **must not** outlive the &self reference. Otherwise, a dangling pointer is created.
unsafe impl<T, R> DeviceRepr for Conjoined<'_, T, R> {
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
addr_of!(self.raw) as *mut _
}
}
30 changes: 18 additions & 12 deletions mistralrs-paged-attn/src/cuda/backend/paged_attention.rs
@@ -1,5 +1,8 @@
use crate::cuda::ffi;
use crate::cuda::ffi::{paged_attention_v1, paged_attention_v2};
use crate::cuda::ffi::{
paged_attention_v1_bf16, paged_attention_v1_f16, paged_attention_v1_f32,
paged_attention_v2_bf16, paged_attention_v2_f16, paged_attention_v2_f32,
};
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
@@ -31,13 +34,6 @@ impl PagedAttention {
q_l: &Layout,
) -> Result<(CudaStorage, Shape)> {
let dtype = q.dtype();
let internal_type = match dtype {
DType::F16 => 0,
DType::BF16 => 1,
DType::F32 => 2,
dtype => candle::bail!("dtype {dtype:?} is not supported"),
};

let cache_dtype = match self.key_cache.dtype() {
DType::F16 => 0,
DType::BF16 => 1,
@@ -233,8 +229,14 @@ impl PagedAttention {
let cl_ptr = *cl.device_ptr() as *const core::ffi::c_int;

if use_v1 {
let paged_attention_v1_func = match dtype {
DType::F16 => paged_attention_v1_f16,
DType::BF16 => paged_attention_v1_bf16,
DType::F32 => paged_attention_v1_f32,
dtype => candle::bail!("dtype {dtype:?} is not supported"),
};
unsafe {
paged_attention_v1(
paged_attention_v1_func(
out_ptr,
q_ptr,
kc_ptr,
@@ -255,7 +257,6 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
*dev.cu_stream(),
internal_type,
cache_dtype,
k_scale_ptr,
v_scale_ptr,
@@ -272,8 +273,14 @@ impl PagedAttention {
let exp_sums_ptr = *exp_sums.device_ptr() as *const f32;
let max_logits_ptr = *max_logits.device_ptr() as *const f32;

let paged_attention_v2_func = match dtype {
DType::F16 => paged_attention_v2_f16,
DType::BF16 => paged_attention_v2_bf16,
DType::F32 => paged_attention_v2_f32,
dtype => candle::bail!("dtype {dtype:?} is not supported"),
};
unsafe {
paged_attention_v2(
paged_attention_v2_func(
out_ptr,
exp_sums_ptr,
max_logits_ptr,
Expand All @@ -297,7 +304,6 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
*dev.cu_stream(),
internal_type,
cache_dtype,
k_scale_ptr,
v_scale_ptr,
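A note on the dispatch pattern introduced above: because the per-dtype paged_attention_v1_* (and v2_*) functions all share one signature, the match can evaluate to a plain fn pointer that is then called once, so the long argument list appears only in one place. A standalone illustration of that Rust pattern, with hypothetical kernel_* functions and a local DType enum that are not part of this PR:

use core::ffi::c_void;

// Hypothetical per-dtype entry points with an identical signature.
unsafe extern "C" fn kernel_f16(_data: *const c_void) {}
unsafe extern "C" fn kernel_f32(_data: *const c_void) {}

#[derive(Clone, Copy)]
enum DType {
    F16,
    F32,
}

fn dispatch(dtype: DType, data: *const c_void) {
    // Every arm coerces to the same fn-pointer type, so `func` is an
    // ordinary pointer and the call site below is written only once.
    let func: unsafe extern "C" fn(*const c_void) = match dtype {
        DType::F16 => kernel_f16,
        DType::F32 => kernel_f32,
    };
    unsafe { func(data) };
}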