Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlx/backend/cuda/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ namespace mlx::core {
}

NO_GPU(BlockMaskedMM)
NO_GPU(GatherQMM)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU_MULTI(SVD)
Expand Down
13 changes: 13 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,17 @@ void qmv(
QuantizationMode mode,
cu::CommandEncoder& encoder);

// Batched quantized matrix-vector product with gather: for each batch
// element, `lhs_indices`/`rhs_indices` select which slice of `x` and which
// quantized weight matrix `w` (with its `scales` and optional `biases`) to
// multiply. Result is written to `out`. `bits`/`group_size`/`mode` describe
// the quantization scheme; kernels are recorded on `encoder`.
void gather_qmv(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);

} // namespace mlx::core
316 changes: 293 additions & 23 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Copyright © 2026 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/backend/cuda/quantized/quantized_utils.h"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>
Expand Down Expand Up @@ -182,38 +185,23 @@ dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {
}

template <
int rows_per_block,
int elems_per_thread,
int group_size,
bool has_bias,
bool has_residue_k,
typename T,
typename Q,
typename S>
__global__ void qmv_kernel(
__device__ __forceinline__ void qmv_impl(
const T* x,
const Q* w,
const S* scales,
const T* biases,
T* out,
int row,
int n,
int k,
bool broadcast_w) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

// The row that this warp handles.
int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
if (row >= n) {
return;
}

// Advance pointers of x/out.
int m = grid.dim_blocks().y;
int l = block.group_index().z;
x += block.group_index().y * k + m * k * l;
out += block.group_index().y * n + m * n * l;
int k) {
auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());

// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
// move past 2 elements for 4-bit Q.
Expand All @@ -224,11 +212,10 @@ __global__ void qmv_kernel(
int groups_per_row = k / group_size;

// Advance w/scales/biases to current row.
int w_batch = broadcast_w ? 0 : l;
w += (static_cast<int64_t>(row) + n * w_batch) * w_step(k);
scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
w += static_cast<int64_t>(row) * w_step(k);
scales += static_cast<int64_t>(row) * groups_per_row;
if constexpr (has_bias) {
biases += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
biases += static_cast<int64_t>(row) * groups_per_row;
}

// Accumulations of current row.
Expand Down Expand Up @@ -274,6 +261,165 @@ __global__ void qmv_kernel(
}
}

// Quantized matrix-vector kernel. Each warp computes one output row:
//   grid.x covers the N output rows in groups of `rows_per_block`
//   (one warp per row, `rows_per_block` warps per block),
//   grid.y indexes the M rows of x,
//   grid.z indexes the (flattened) batch dimension.
// When `broadcast_w` is true the same weight matrix is reused for every
// batch element; otherwise w/scales/biases advance per batch slice.
// Per-row dequantize/FMA/reduce/write is delegated to qmv_impl.
template <
int rows_per_block,
int elems_per_thread,
int group_size,
bool has_bias,
bool has_residue_k,
typename T,
typename Q,
typename S>
__global__ void qmv_kernel(
const T* x,
const Q* w,
const S* scales,
const T* biases,
T* out,
int n,
int k,
bool broadcast_w) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

// The row that this warp handles.
int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
if (row >= n) {
// Tail block: fewer than rows_per_block rows remain.
return;
}

// Advance pointers of x/out for M and batch dimensions.
// m is the M extent (grid.y), l the batch index (grid.z); x rows are
// assumed contiguous with stride k, out rows with stride n.
int m = grid.dim_blocks().y;
int l = block.group_index().z;
x += block.group_index().y * k + m * k * l;
out += block.group_index().y * n + m * n * l;

// Advance w/scales/biases for batch dimension.
// w_step converts an element count to a pointer advance for sub-byte Q
// (e.g. 4-bit packs 2 elements per byte); scales/biases have one entry
// per group of `group_size` elements along K.
constexpr int bits = cute::sizeof_bits_v<Q>;
auto w_step = [&](int idx) { return idx * cuda::std::min(8, bits) / 8; };
int groups_per_row = k / group_size;
int w_batch = broadcast_w ? 0 : l;
w += static_cast<int64_t>(n) * w_batch * w_step(k);
scales += static_cast<int64_t>(n) * w_batch * groups_per_row;
if constexpr (has_bias) {
biases += static_cast<int64_t>(n) * w_batch * groups_per_row;
}

// Row-level compute: dequantize, FMA, reduce, write.
qmv_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
x, w, scales, biases, out, row, n, k);
}

// Gather variant of qmv_kernel. Grid layout matches qmv_kernel
// (grid.x = row groups, grid.y = M, grid.z = batch), but instead of a
// linear batch offset, each batch element looks up `lhs_indices`/
// `rhs_indices` to select which x slice and which w/scales/biases slice
// to use. Shape/Strides parameters are passed by value in
// __grid_constant__ kernel parameter space (requires SM70+/CUDA 11.7+).
// NOTE(review): assumes index values fit in uint32_t and that
// `output_stride` (= m * n per batch element, set by the host launcher)
// does not overflow int — confirm for very large outputs.
template <
int rows_per_block,
int elems_per_thread,
int group_size,
bool has_bias,
bool has_residue_k,
typename T,
typename Q,
typename S>
__global__ void gather_qmv_kernel(
const T* x,
const uint32_t* w,
const S* scales,
const T* biases,
T* out,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices,
int n,
int k,
int x_batch_ndims,
const __grid_constant__ Shape x_batch_shape,
const __grid_constant__ Strides x_batch_strides,
int w_batch_ndims,
const __grid_constant__ Shape w_batch_shape,
const __grid_constant__ Strides w_batch_strides,
const __grid_constant__ Strides s_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int index_ndims,
const __grid_constant__ Shape index_shape,
const __grid_constant__ Strides lhs_index_strides,
const __grid_constant__ Strides rhs_index_strides,
int output_stride) {
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

// The row that this warp handles.
int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
if (row >= n) {
// Tail block: fewer than rows_per_block rows remain.
return;
}

// Gather: look up batch indices.
// Fast path for 1-D (collapsed-contiguous) index arrays; otherwise use
// the generic strided elem_to_loc to locate both index entries at once.
uint32_t batch_idx = block.group_index().z;
uint32_t x_idx, w_idx;
if (index_ndims == 1) {
x_idx = lhs_indices[batch_idx * lhs_index_strides[0]];
w_idx = rhs_indices[batch_idx * rhs_index_strides[0]];
} else {
auto [lhs_off, rhs_off] = elem_to_loc(
batch_idx,
index_shape.data(),
lhs_index_strides.data(),
rhs_index_strides.data(),
index_ndims);
x_idx = lhs_indices[lhs_off];
w_idx = rhs_indices[rhs_off];
}

// Offset x using gathered index.
if (x_batch_ndims == 1) {
x += x_idx * x_batch_strides[0];
} else {
x += elem_to_loc(
x_idx, x_batch_shape.data(), x_batch_strides.data(), x_batch_ndims);
}

// Offset w/scales/biases using gathered index.
// The three arrays share w_batch_shape but each carries its own strides;
// the multi-stride elem_to_loc overloads compute all offsets in one pass.
if (w_batch_ndims == 1) {
w += w_idx * w_batch_strides[0];
scales += w_idx * s_batch_strides[0];
if constexpr (has_bias) {
biases += w_idx * b_batch_strides[0];
}
} else {
if constexpr (has_bias) {
auto [w_off, s_off, b_off] = elem_to_loc(
w_idx,
w_batch_shape.data(),
w_batch_strides.data(),
s_batch_strides.data(),
b_batch_strides.data(),
w_batch_ndims);
w += w_off;
scales += s_off;
biases += b_off;
} else {
auto [w_off, s_off] = elem_to_loc(
w_idx,
w_batch_shape.data(),
w_batch_strides.data(),
s_batch_strides.data(),
w_batch_ndims);
w += w_off;
scales += s_off;
}
}

// Offset output for this batch element.
out += batch_idx * output_stride;

// Advance pointers for M dimension (block.group_index().y).
// x rows are assumed contiguous with stride k, out rows with stride n.
x += block.group_index().y * k;
out += block.group_index().y * n;

// Reinterpret w as Q* for sub-byte access, then run shared compute.
qmv_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
x, reinterpret_cast<const Q*>(w), scales, biases, out, row, n, k);
}

template <
int group_size,
bool has_bias,
Expand Down Expand Up @@ -433,4 +579,128 @@ void qmv(
});
}

// Host launcher for gather_qmv_kernel. Dispatches on output dtype and
// quantization type, then records one kernel launch on `encoder` with
// grid (ceil(n / rows_per_block), m, B) and rows_per_block warps per
// block — one warp per output row per batch element.
void gather_qmv(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
// Tag used in dispatch error messages; kept as [gather_qmm] to match the
// public primitive name.
const char* tag = "[gather_qmm]";
int m = out.shape(-2);
int n = out.shape(-1);
int k = x.shape(-1);
// Number of batch elements = product of all leading output dims.
int B = out.size() / (m * n);

// Collapse contiguous dims for index arrays.
// Produces a shared shape plus one stride set per index array so the
// kernel can take its fast 1-D path when both are contiguous.
auto [idx_shape, idx_strides] = collapse_contiguous_dims(
lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});

dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
dispatch_quant_types<T>(
bits,
group_size,
mode,
tag,
// Note: this template parameter deliberately shadows the outer
// `group_size` argument with a compile-time value.
[&]<typename Q, typename S, int group_size>() {
// Register all arrays with the encoder for dependency tracking.
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
if (biases) {
encoder.set_input_array(*biases);
}
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);

// Quant formats whose Q type can represent a negative zero carry
// no separate bias array. Narrow Q with small T gets more work
// per thread.
constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
constexpr int rows_per_block = 8;
constexpr int elems_per_thread =
(cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16
: 8;

// grid.x = row groups, grid.y = M, grid.z = batch elements.
dim3 num_blocks{
uint32_t(cuda::ceil_div(n, rows_per_block)),
uint32_t(m),
uint32_t(B)};
dim3 block_dims{WARP_SIZE, rows_per_block};
// Elements between consecutive batch slices of the output.
int output_stride = m * n;

auto x_ptr = gpu_ptr<T>(x);
auto w_ptr = gpu_ptr<uint32_t>(w);
auto s_ptr = gpu_ptr<S>(scales);
auto b_ptr = biases ? gpu_ptr<T>(*biases) : (const T*)nullptr;
auto o_ptr = gpu_ptr<T>(out);
auto li_ptr = gpu_ptr<uint32_t>(lhs_indices);
auto ri_ptr = gpu_ptr<uint32_t>(rhs_indices);

// Batch dims are everything except the trailing matrix dims; the
// kernel reads only the first *_batch_ndims entries of each
// shape/stride parameter.
int x_batch_ndims = x.ndim() - 2;
auto x_shape_p = const_param(x.shape());
auto x_strides_p = const_param<MAX_NDIM, int64_t>(x.strides());
int w_batch_ndims = w.ndim() - 2;
auto w_shape_p = const_param(w.shape());
auto w_strides_p = const_param<MAX_NDIM, int64_t>(w.strides());
auto s_strides_p = const_param<MAX_NDIM, int64_t>(scales.strides());
// Without biases, pass a zeroed strides pack; the kernel never
// reads it (has_bias is false in that instantiation).
auto b_strides_p = biases
? const_param<MAX_NDIM, int64_t>(biases->strides())
: decltype(s_strides_p){};
int index_ndims = idx_shape.size();
auto idx_shape_p = const_param(idx_shape);
auto lhs_idx_strides_p =
const_param<MAX_NDIM, int64_t>(idx_strides[0]);
auto rhs_idx_strides_p =
const_param<MAX_NDIM, int64_t>(idx_strides[1]);

// Argument order must match gather_qmv_kernel's parameter list.
void* args[] = {
&x_ptr,
&w_ptr,
&s_ptr,
&b_ptr,
&o_ptr,
&li_ptr,
&ri_ptr,
&n,
&k,
&x_batch_ndims,
&x_shape_p,
&x_strides_p,
&w_batch_ndims,
&w_shape_p,
&w_strides_p,
&s_strides_p,
&b_strides_p,
&index_ndims,
&idx_shape_p,
&lhs_idx_strides_p,
&rhs_idx_strides_p,
&output_stride};

// has_residue_k selects the kernel variant that handles a K not
// divisible by one full warp pass (WARP_SIZE * elems_per_thread).
dispatch_bool(
k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
auto* kernel = &cu::gather_qmv_kernel<
rows_per_block,
elems_per_thread,
group_size,
has_bias,
has_residue_k.value,
T,
Q,
S>;
encoder.add_kernel_node_raw(
reinterpret_cast<void*>(kernel),
num_blocks,
block_dims,
{},
0,
args);
});
});
});
}

} // namespace mlx::core
Loading
Loading