Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 57 additions & 33 deletions mlx/backend/cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
const int* offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;

// Compute costheta, sintheta
float theta = L * inv_freq;
Expand All @@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
Expand Down Expand Up @@ -167,7 +172,8 @@ __global__ void rope(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
Expand All @@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
Expand All @@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
Expand All @@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
Expand All @@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1];
auto& out = outputs[0];

if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}

cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();

int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);

int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}

// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
Expand Down Expand Up @@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];

// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;

auto& encoder = cu::get_command_encoder(s);
Expand All @@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
Expand All @@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
Expand All @@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
Expand All @@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
Expand All @@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims);
}
});
Expand Down
Loading