diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 89c6622709..d8a093566d 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -513,7 +513,6 @@ impl PhiRotaryEmbedding {
     ) -> Result<(Tensor, Tensor)> {
         let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
-        let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
 
         let rot_dim = cos.dim(D::Minus1)? * 2;
 
@@ -525,7 +524,7 @@ impl PhiRotaryEmbedding {
         let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
         let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
 
-        let (q_rot, k_rot) = if all_same {
+        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
             let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
             let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
             let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
@@ -559,7 +558,7 @@ impl PhiRotaryEmbedding {
             Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
             Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
         ))
-    } else if all_same {
+    } else if seqlen_offsets.len() == 1 {
         let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
         let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
         let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
@@ -1106,8 +1105,8 @@ impl DeepSeekV2RotaryEmbedding {
         seqlen_offsets: &[usize],
     ) -> Result<(Tensor, Tensor)> {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
-        let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
-        if all_same {
+
+        if seqlen_offsets.len() == 1 {
             let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
             let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
             let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
@@ -1296,8 +1295,7 @@ impl Phi4MMRotaryEmbedding {
         let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
         let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
 
-        let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
-        let (q_rot, k_rot) = if all_same {
+        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
             let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
             let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
             let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
@@ -1577,7 +1575,8 @@ impl RotaryEmbedding {
         k: &Tensor,
         seqlen_offsets: &[usize],
     ) -> Result<(Tensor, Tensor)> {
-        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
+        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
 
         let rope = if self.is_gpt_neox {
             candle_nn::rotary_emb::rope
@@ -1585,8 +1584,43 @@ impl RotaryEmbedding {
             candle_nn::rotary_emb::rope_i
         };
 
-        let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
-        if all_same {
+        if cfg!(feature = "cuda") {
+            let (cos, sin) = if seqlen_offsets.len() == 1 {
+                (
+                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
+                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
+                )
+            } else {
+                let mut cos_s = Vec::new();
+                let mut sin_s = Vec::new();
+                for offset in seqlen_offsets {
+                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
+                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
+                }
+                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
+            };
+
+            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
+            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
+            mistralrs_quant::rotary::apply_rotary_inplace(
+                &q_embed,
+                &k_embed,
+                &cos,
+                &sin,
+                self.is_gpt_neox,
+            )?;
+            let mut q = q_embed
+                .reshape((b_sz, seq_len, qh, n_embd))?
+                .transpose(1, 2)?;
+            let mut k = k_embed
+                .reshape((b_sz, seq_len, kh, n_embd))?
+                .transpose(1, 2)?;
+            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
+                q = q.contiguous()?;
+                k = k.contiguous()?;
+            }
+            Ok((q, k))
+        } else if seqlen_offsets.len() == 1 {
             let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
             let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
             let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs
index 9349789809..e52341f338 100644
--- a/mistralrs-quant/build.rs
+++ b/mistralrs-quant/build.rs
@@ -85,6 +85,7 @@ fn main() {
         "kernels/hqq/hqq.cu",
         "kernels/ops/ops.cu",
         "kernels/bitsandbytes/dequant.cu",
+        "kernels/rotary/rotary.cu",
     ];
     if cc_over_800 {
         lib_files.push("kernels/marlin/marlin_kernel.cu");
diff --git a/mistralrs-quant/kernels/rotary/cuda_compat.h b/mistralrs-quant/kernels/rotary/cuda_compat.h
new file mode 100644
index 0000000000..ed3ebe7dc4
--- /dev/null
+++ b/mistralrs-quant/kernels/rotary/cuda_compat.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif
diff --git a/mistralrs-quant/kernels/rotary/rotary.cu b/mistralrs-quant/kernels/rotary/rotary.cu
new file mode 100644
index 0000000000..9e65ee226c
--- /dev/null
+++ b/mistralrs-quant/kernels/rotary/rotary.cu
@@ -0,0 +1,131 @@
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "cuda_compat.h"
+
+namespace vllm {
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+    scalar_t* __restrict__ arr,
+    const scalar_t* __restrict__ cos_ptr,
+    const scalar_t* __restrict__ sin_ptr,
+    int rot_offset,
+    int rot_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = rot_dim + rot_offset;
+    cos = VLLM_LDG(cos_ptr + x_index);
+    sin = VLLM_LDG(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = VLLM_LDG(cos_ptr + x_index / 2);
+    sin = VLLM_LDG(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+    scalar_t* __restrict__ query,            // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ key,              // [num_tokens, num_kv_heads, head_size]
+    const scalar_t* __restrict__ cos_cache,  // [num_tokens, rot_dim]
+    const scalar_t* __restrict__ sin_cache,  // [num_tokens, rot_dim]
+    const int rot_dim,
+    const int64_t query_stride,
+    const int64_t key_stride,
+    const int num_heads,
+    const int num_kv_heads,
+    const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+
+  const scalar_t* cos_ptr = cos_cache + token_idx * rot_dim;
+  const scalar_t* sin_ptr = sin_cache + token_idx * rot_dim;
+
+  const int nq = num_heads * rot_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / rot_dim;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int rot_offset = i % rot_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, rot_dim);
+  }
+
+  const int nk = num_kv_heads * rot_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / rot_dim;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % rot_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, rot_dim);
+  }
+}
+
+} // namespace vllm
+
+#define CALL_ROTARY(T, IS_NEOX)                                          \
+  vllm::rotary_embedding_kernel<T, IS_NEOX><<<grid, block, 0, stream>>>( \
+      reinterpret_cast<T*>(query),                                       \
+      reinterpret_cast<T*>(key),                                         \
+      reinterpret_cast<T*>(cos_cache),                                   \
+      reinterpret_cast<T*>(sin_cache),                                   \
+      rot_dim,                                                           \
+      query_stride,                                                      \
+      key_stride,                                                        \
+      num_heads,                                                         \
+      num_kv_heads,                                                      \
+      head_size);
+
+extern "C" void rotary_embedding(
+    void *query,      // [num_tokens, num_heads, head_size]
+    void *key,        // [num_tokens, num_kv_heads, head_size]
+    void *cos_cache,  // [num_tokens, rot_dim]
+    void *sin_cache,  // [num_tokens, rot_dim]
+    int32_t is_neox,
+
+    int32_t head_size,
+    int64_t num_tokens,
+    int32_t rot_dim,
+    int32_t num_heads,
+    int32_t num_kv_heads,
+    int64_t query_stride,
+    int64_t key_stride,
+
+    uint32_t dtype  // 0 => f16; 1 => bf16; 2 => f32
+) {
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim, 512));
+  const cudaStream_t stream = 0;
+  const bool is_neox_bool = is_neox;
+
+  if (is_neox_bool) {
+    if (dtype == 0) {
+      CALL_ROTARY(half, true);
+    } else if (dtype == 1) {
+      CALL_ROTARY(__nv_bfloat16, true);
+    } else if (dtype == 2) {
+      CALL_ROTARY(float, true);
+    }
+  } else {
+    if (dtype == 0) {
+      CALL_ROTARY(half, false);
+    } else if (dtype == 1) {
+      CALL_ROTARY(__nv_bfloat16, false);
+    } else if (dtype == 2) {
+      CALL_ROTARY(float, false);
+    }
+  }
+}
\ No newline at end of file
diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs
index 607d2b24dc..2ff2bb9fbd 100644
--- a/mistralrs-quant/src/lib.rs
+++ b/mistralrs-quant/src/lib.rs
@@ -24,6 +24,7 @@ mod gguf;
 mod gptq;
 mod hqq;
 mod imatrix;
+pub mod rotary;
 pub mod safetensors;
 mod static_lora;
 mod unquantized;
diff --git a/mistralrs-quant/src/rotary/ffi.rs b/mistralrs-quant/src/rotary/ffi.rs
new file mode 100644
index 0000000000..3839177b56
--- /dev/null
+++ b/mistralrs-quant/src/rotary/ffi.rs
@@ -0,0 +1,22 @@
+use core::ffi::{c_int, c_long, c_void};
+
+extern "C" { + pub(crate) fn rotary_embedding( + query: *const c_void, + key: *const c_void, + cos_cache: *const c_void, + sin_cache: *const c_void, + + is_neox: c_int, + + head_size: c_int, + num_tokens: c_long, + rot_dim: c_int, + num_heads: c_int, + num_kv_heads: c_int, + query_stride: c_long, + key_stride: c_long, + + dtype: u32, + ); +} diff --git a/mistralrs-quant/src/rotary/mod.rs b/mistralrs-quant/src/rotary/mod.rs new file mode 100644 index 0000000000..08f539bd78 --- /dev/null +++ b/mistralrs-quant/src/rotary/mod.rs @@ -0,0 +1,188 @@ +#[cfg(feature = "cuda")] +mod ffi; + +#[cfg(feature = "cuda")] +mod cuda { + use candle_core::cuda_backend::cudarc::driver::DevicePtr; + use candle_core::{DType, Result, Storage, Tensor}; + use half::{bf16, f16}; + use std::ffi::{c_int, c_long}; + + fn apply_rotary_< + T: candle_core::cuda_backend::CudaDType + + candle_core::cuda_backend::cudarc::driver::DeviceRepr, + >( + query: &Tensor, + key: &Tensor, + cos_cache: &Tensor, + sin_cache: &Tensor, + is_neox: bool, + ) -> Result<()> { + let dtype = query.dtype(); + if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype { + candle_core::bail!("apply-rotary expects all tensors to have the same dtype"); + } + + let internal_type = match dtype { + DType::F16 => 0, + DType::BF16 => 1, + DType::F32 => 2, + dtype => candle_core::bail!("dtype {dtype:?} is not supported"), + }; + + let (q, q_l) = query.storage_and_layout(); + let q = match &*q { + Storage::Cuda(q) => q, + _ => candle_core::bail!("query must be a cuda tensor"), + }; + + let (k, k_l) = key.storage_and_layout(); + let k = match &*k { + Storage::Cuda(k) => k, + _ => candle_core::bail!("key must be a cuda tensor"), + }; + + let (cc, cc_l) = cos_cache.storage_and_layout(); + let cc = match &*cc { + Storage::Cuda(cc) => cc, + _ => candle_core::bail!("cos_cache must be a cuda tensor"), + }; + + let (sc, sc_l) = sin_cache.storage_and_layout(); + let sc = match &*sc { + Storage::Cuda(sc) => sc, + _ => candle_core::bail!("sin_cache must be a cuda tensor"), + }; + + let q_rank = q_l.stride().len(); + let k_rank = k_l.stride().len(); + let cc_rank = cc_l.stride().len(); + let sc_rank = sc_l.stride().len(); + + if q_rank != 3 || k_rank != 3 { + candle_core::bail!( + "apply-rotary expects input tensors of rank 3 (k: {q_l:?}, v: {k_l:?})" + ) + } + + if cc_rank != 2 || sc_rank != 2 { + candle_core::bail!( + "apply-rotary expects cache tensors of rank 2 (k: {cc_l:?}, v: {sc_l:?})" + ) + } + + // Get cuda slices for all tensors + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let cc = cc.as_cuda_slice::()?; + let sc = sc.as_cuda_slice::()?; + + // Get cuda views for all tensors + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let cc = cc.slice(cc_l.start_offset()..); + let sc = sc.slice(sc_l.start_offset()..); + + let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?; + let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?; + + if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) { + candle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + + let rot_dim = cc_l.dims()[1]; + if (num_tokens, rot_dim) != cc_l.shape().dims2()? { + candle_core::bail!( + "shape mismatch cos_cache {:?}, expected {:?}", + cc_l.shape(), + (num_tokens, rot_dim) + ) + } + + if (num_tokens, rot_dim) != sc_l.shape().dims2()? 
+            candle_core::bail!(
+                "shape mismatch sin_cache {:?}, expected {:?}",
+                sc_l.shape(),
+                (num_tokens, rot_dim)
+            )
+        }
+
+        let query_stride = q_l.stride()[0];
+        let key_stride = k_l.stride()[0];
+
+        let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
+        let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
+        let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
+        let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;
+
+        let neox = if is_neox { 1 } else { 0 };
+
+        unsafe {
+            super::ffi::rotary_embedding(
+                q_ptr,
+                k_ptr,
+                cc_ptr,
+                sc_ptr,
+                neox,
+                head_size as c_int,
+                num_tokens as c_long,
+                rot_dim as c_int,
+                num_heads as c_int,
+                num_kv_heads as c_int,
+                query_stride as c_long,
+                key_stride as c_long,
+                internal_type,
+            )
+        }
+        Ok(())
+    }
+
+    /// Apply Rotary position encoding inplace
+    ///
+    /// # Arguments
+    ///
+    /// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
+    /// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
+    /// * `cos_cache` - Aligned cache of shape `(num_tokens, rot_dim)`
+    /// * `sin_cache` - Aligned cache of shape `(num_tokens, rot_dim)`
+    /// * `is_neox` - Use neox encoding instead of gpt-j style rotary
+    pub fn apply_rotary_inplace(
+        query: &Tensor,
+        key: &Tensor,
+        cos_cache: &Tensor,
+        sin_cache: &Tensor,
+        is_neox: bool,
+    ) -> Result<()> {
+        match key.dtype() {
+            DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
+            DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
+            DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
+            dt => {
+                candle_core::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
+            }
+        }
+    }
+}
+
+#[cfg(feature = "cuda")]
+pub use cuda::*;
+
+/// Apply Rotary position encoding inplace
+///
+/// # Arguments
+///
+/// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
+/// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
+/// * `cos_cache` - Aligned cache of shape `(num_tokens, rot_dim)`
+/// * `sin_cache` - Aligned cache of shape `(num_tokens, rot_dim)`
+/// * `is_neox` - Use neox encoding instead of gpt-j style rotary
+#[cfg(not(feature = "cuda"))]
+pub fn apply_rotary_inplace(
+    _query: &candle_core::Tensor,
+    _key: &candle_core::Tensor,
+    _cos_cache: &candle_core::Tensor,
+    _sin_cache: &candle_core::Tensor,
+    _is_neox: bool,
+) -> candle_core::Result<()> {
+    candle_core::bail!("apply_rotary is only supported for cuda");
+}
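---

For reference (not part of the patch), a minimal sketch of how the new entry point could be exercised. It assumes a build with the `cuda` feature on a machine with a CUDA device, and follows the shapes in the doc comment above: `query`/`key` are `(num_tokens, num_heads, head_size)` views, and the cos/sin caches carry one pre-gathered row per token with `rot_dim == head_size / 2` for the fused NeoX path. The concrete sizes below are illustrative only.

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let dev = Device::new_cuda(0)?;
        let (num_tokens, num_heads, num_kv_heads, head_size) = (4, 8, 2, 64);
        // One cos/sin row per token; half the head dim is rotated in the NeoX layout.
        let rot_dim = head_size / 2;

        // Q/K in the (num_tokens, num_heads, head_size) layout the kernel expects;
        // layers.rs produces this via transpose(1, 2) + flatten(0, 1).
        let q = Tensor::randn(0f32, 1.0, (num_tokens, num_heads, head_size), &dev)?;
        let k = Tensor::randn(0f32, 1.0, (num_tokens, num_kv_heads, head_size), &dev)?;

        // Stand-in caches; in practice these are narrow/cat slices of the
        // precomputed cos/sin tables, one row per position offset.
        let cos = Tensor::randn(0f32, 1.0, (num_tokens, rot_dim), &dev)?;
        let sin = Tensor::randn(0f32, 1.0, (num_tokens, rot_dim), &dev)?;

        // Rotates q and k in place on the GPU; `true` selects the GPT-NeoX layout.
        mistralrs_quant::rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
        Ok(())
    }

The in-place mutation is why the layers.rs call site can skip the `Tensor::cat` of per-sequence rope outputs that the fallback path still performs: one kernel launch covers all tokens, with one thread block per token over both the query and key heads.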