54 changes: 44 additions & 10 deletions mistralrs-core/src/layers.rs
@@ -513,7 +513,6 @@ impl PhiRotaryEmbedding {
) -> Result<(Tensor, Tensor)> {
let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);

let rot_dim = cos.dim(D::Minus1)? * 2;

@@ -525,7 +524,7 @@ impl PhiRotaryEmbedding {
let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

let (q_rot, k_rot) = if all_same {
let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
@@ -559,7 +558,7 @@ impl PhiRotaryEmbedding {
Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
))
} else if all_same {
} else if seqlen_offsets.len() == 1 {
let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
@@ -1106,8 +1105,8 @@ impl DeepSeekV2RotaryEmbedding {
seqlen_offsets: &[usize],
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
if all_same {

if seqlen_offsets.len() == 1 {
let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
@@ -1296,8 +1295,7 @@ impl Phi4MMRotaryEmbedding {
let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
let (q_rot, k_rot) = if all_same {
let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
@@ -1577,16 +1575,52 @@ impl RotaryEmbedding {
k: &Tensor,
seqlen_offsets: &[usize],
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;

let rope = if self.is_gpt_neox {
candle_nn::rotary_emb::rope
} else {
candle_nn::rotary_emb::rope_i
};

let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
if all_same {
if cfg!(feature = "cuda") {
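// CUDA: route through the fused in-place rotary kernel from mistralrs_quant instead of candle's rope/rope_i ops.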
let (cos, sin) = if seqlen_offsets.len() == 1 {
(
self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
)
} else {
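// Ragged offsets: gather each sequence's cos/sin rows and concatenate them so the kernel receives per-token tables.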
let mut cos_s = Vec::new();
let mut sin_s = Vec::new();
for offset in seqlen_offsets {
cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
}
(Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
};

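// Flatten [b_sz, heads, seq_len, head_dim] -> [num_tokens, heads, head_dim] so the kernel can handle one token per thread block (see rotary.cu).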
let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
mistralrs_quant::rotary::apply_rotary_inplace(
&q_embed,
&k_embed,
&cos,
&sin,
self.is_gpt_neox,
)?;
let mut q = q_embed
.reshape((b_sz, seq_len, qh, n_embd))?
.transpose(1, 2)?;
let mut k = k_embed
.reshape((b_sz, seq_len, kh, n_embd))?
.transpose(1, 2)?;
if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
q = q.contiguous()?;
k = k.contiguous()?;
}
Ok((q, k))
} else if seqlen_offsets.len() == 1 {
let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1 change: 1 addition & 0 deletions mistralrs-quant/build.rs
@@ -85,6 +85,7 @@ fn main() {
"kernels/hqq/hqq.cu",
"kernels/ops/ops.cu",
"kernels/bitsandbytes/dequant.cu",
"kernels/rotary/rotary.cu",
];
if cc_over_800 {
lib_files.push("kernels/marlin/marlin_kernel.cu");
27 changes: 27 additions & 0 deletions mistralrs-quant/kernels/rotary/cuda_compat.h
@@ -0,0 +1,27 @@
#pragma once
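// Portability shims: on ROCm builds (USE_ROCM) the CUDA-only __ldg and warp shuffle intrinsics map to their HIP equivalents.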

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
131 changes: 131 additions & 0 deletions mistralrs-quant/kernels/rotary/rotary.cu
@@ -0,0 +1,131 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdint.h>

#include "cuda_compat.h"

namespace vllm {

template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int rot_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = rot_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
}

const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}

template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_cache, // [num_tokens, rot_dim]
const scalar_t* __restrict__ sin_cache, // [num_tokens, rot_dim]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;

const scalar_t* cos_ptr = cos_cache + token_idx * rot_dim;
const scalar_t* sin_ptr = sin_cache + token_idx * rot_dim;

const int nq = num_heads * rot_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}

const int nk = num_kv_heads * rot_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}
}

} // namespace vllm

#define CALL_ROTARY(T, IS_NEOX) \
vllm::rotary_embedding_kernel<T, IS_NEOX><<<grid, block, 0, stream>>>( \
reinterpret_cast<T*>(query), \
reinterpret_cast<T*>(key), \
reinterpret_cast<T*>(cos_cache), \
reinterpret_cast<T*>(sin_cache), \
rot_dim, \
query_stride, \
key_stride, \
num_heads, \
num_kv_heads, \
head_size);

extern "C" void rotary_embedding(
void *query, // [num_tokens, num_heads, head_size]
void *key, // [num_tokens, num_kv_heads, head_size]
void *cos_cache, // [num_tokens, rot_dim]
void *sin_cache, // [num_tokens, rot_dim]
int32_t is_neox,

int32_t head_size,
int64_t num_tokens,
int32_t rot_dim,
int32_t num_heads,
int32_t num_kv_heads,
int64_t query_stride,
int64_t key_stride,

uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
) {

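// One thread block per token; its (up to 512) threads stride over the num_heads * rot_dim rotations for that token.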
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim, 512));
const cudaStream_t stream = 0;
const bool is_neox_bool = is_neox;

if (is_neox_bool) {
if (dtype == 0){
CALL_ROTARY(half, true);
} else if (dtype == 1) {
CALL_ROTARY(__nv_bfloat16, true);
} else if (dtype == 2) {
CALL_ROTARY(float, true);
}
} else {
if (dtype == 0){
CALL_ROTARY(half, false);
} else if (dtype == 1) {
CALL_ROTARY(__nv_bfloat16, false);
} else if (dtype == 2) {
CALL_ROTARY(float, false);
}
}
}
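For reference, here is a CPU sketch in Rust of the per-token rotation the kernel applies; the function name and f32 element type are illustrative and not part of this PR. The NeoX branch pairs element i with element i + rot_dim, while the GPT-J branch rotates the adjacent pair (2i, 2i + 1), matching the two branches of apply_rotary_embedding above.

// Illustrative CPU reference of the kernel's inner loop for one head of one token.
// `head` is that head's slice (only the first 2 * rot_dim elements are rotated);
// `cos`/`sin` are this token's rows of length rot_dim.
fn apply_rotary_ref(head: &mut [f32], cos: &[f32], sin: &[f32], rot_dim: usize, is_neox: bool) {
    for rot_offset in 0..rot_dim {
        let (x_index, y_index) = if is_neox {
            // GPT-NeoX style: rotate element i with element i + rot_dim.
            (rot_offset, rot_dim + rot_offset)
        } else {
            // GPT-J style: rotate the adjacent pair (2i, 2i + 1).
            (2 * rot_offset, 2 * rot_offset + 1)
        };
        let (c, s) = (cos[rot_offset], sin[rot_offset]);
        let (x, y) = (head[x_index], head[y_index]);
        head[x_index] = x * c - y * s;
        head[y_index] = y * c + x * s;
    }
}

The kernel runs this loop once per (token, head) pair, launching one thread block per token and letting threads stride over num_heads * rot_dim for the query (and num_kv_heads * rot_dim for the key).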
1 change: 1 addition & 0 deletions mistralrs-quant/src/lib.rs
@@ -24,6 +24,7 @@ mod gguf;
mod gptq;
mod hqq;
mod imatrix;
pub mod rotary;
pub mod safetensors;
mod static_lora;
mod unquantized;
22 changes: 22 additions & 0 deletions mistralrs-quant/src/rotary/ffi.rs
@@ -0,0 +1,22 @@
use core::ffi::{c_int, c_long, c_void};

extern "C" {
pub(crate) fn rotary_embedding(
query: *const c_void,
key: *const c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,

is_neox: c_int,

head_size: c_int,
num_tokens: c_long,
rot_dim: c_int,
num_heads: c_int,
num_kv_heads: c_int,
query_stride: c_long,
key_stride: c_long,

dtype: u32,
);
}
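The safe wrapper called from layers.rs (mistralrs_quant::rotary::apply_rotary_inplace) is not shown in this diff. As a rough illustration only, a thin pass-through over the raw binding could look like the sketch below; the function and enum names are hypothetical, and it assumes contiguous [num_tokens, heads, head_size] buffers so the per-token strides are simply heads * head_size.

use core::ffi::{c_int, c_long, c_void};

// Hypothetical dtype selector mirroring the kernel's codes: 0 => f16, 1 => bf16, 2 => f32.
pub enum RotaryDType {
    F16,
    BF16,
    F32,
}

/// # Safety
/// `query`, `key`, `cos` and `sin` must be valid device pointers with the
/// layouts documented in rotary.cu, all of element type `dtype`.
pub unsafe fn rotary_embedding_raw(
    query: *const c_void,
    key: *const c_void,
    cos: *const c_void,
    sin: *const c_void,
    is_neox: bool,
    head_size: c_int,
    num_tokens: c_long,
    rot_dim: c_int,
    num_heads: c_int,
    num_kv_heads: c_int,
    dtype: RotaryDType,
) {
    let dtype = match dtype {
        RotaryDType::F16 => 0u32,
        RotaryDType::BF16 => 1u32,
        RotaryDType::F32 => 2u32,
    };
    // Contiguous [num_tokens, heads, head_size] layout: token stride = heads * head_size.
    let query_stride = (num_heads * head_size) as c_long;
    let key_stride = (num_kv_heads * head_size) as c_long;
    ffi::rotary_embedding(
        query,
        key,
        cos,
        sin,
        is_neox as c_int,
        head_size,
        num_tokens,
        rot_dim,
        num_heads,
        num_kv_heads,
        query_stride,
        key_stride,
        dtype,
    );
}

Note that the binding takes *const pointers even though the kernel rewrites query and key in place, which is why the layers.rs call site treats the operation as an in-place update and reshapes the same buffers afterwards.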