Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlx/backend/cuda/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ namespace mlx::core {
}

NO_GPU(BlockMaskedMM)
NO_GPU(GatherQMM)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU_MULTI(SVD)
Expand Down
13 changes: 13 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,17 @@ void qmv(
QuantizationMode mode,
cu::CommandEncoder& encoder);

// Gathered quantized matrix-vector multiply. For each output batch element b,
// multiplies the activation selected by lhs_indices[b] with the dequantized
// weight matrix selected by rhs_indices[b] (per-group scales, optional
// per-group biases), writing into the b-th slice of `out`. Records the kernel
// launch on `encoder`; `out` must already have its buffer allocated.
void gather_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);

} // namespace mlx::core
240 changes: 217 additions & 23 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,38 +182,23 @@ dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {
}

template <
int rows_per_block,
int elems_per_thread,
int group_size,
bool has_bias,
bool has_residue_k,
typename T,
typename Q,
typename S>
__global__ void qmv_kernel(
__device__ __forceinline__ void qmv_impl(
const T* x,
const Q* w,
const S* scales,
const T* biases,
T* out,
int row,
int n,
int k,
bool broadcast_w) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

// The row that this warp handles.
int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
if (row >= n) {
return;
}

// Advance pointers of x/out.
int m = grid.dim_blocks().y;
int l = block.group_index().z;
x += block.group_index().y * k + m * k * l;
out += block.group_index().y * n + m * n * l;
int k) {
auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());

// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
// move past 2 elements for 4-bit Q.
Expand All @@ -224,11 +209,10 @@ __global__ void qmv_kernel(
int groups_per_row = k / group_size;

// Advance w/scales/biases to current row.
int w_batch = broadcast_w ? 0 : l;
w += (static_cast<int64_t>(row) + n * w_batch) * w_step(k);
scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
w += static_cast<int64_t>(row) * w_step(k);
scales += static_cast<int64_t>(row) * groups_per_row;
if constexpr (has_bias) {
biases += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
biases += static_cast<int64_t>(row) * groups_per_row;
}

// Accumulations of current row.
Expand Down Expand Up @@ -274,6 +258,114 @@ __global__ void qmv_kernel(
}
}

// Quantized matrix-vector kernel: one warp produces one output row via the
// shared qmv_impl path. Launch layout: grid.x covers the n output rows
// (rows_per_block warps per block), grid.y covers the M dimension, grid.z
// covers the batch. With broadcast_w set, all batch elements reuse the
// weights/scales/biases of batch 0.
template <
    int rows_per_block,
    int elems_per_thread,
    int group_size,
    bool has_bias,
    bool has_residue_k,
    typename T,
    typename Q,
    typename S>
__global__ void qmv_kernel(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    int n,
    int k,
    bool broadcast_w) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  // Output row owned by this warp; warps past the end of the matrix bail out.
  const int row =
      block.group_index().x * rows_per_block + warp.meta_group_rank();
  if (row >= n) {
    return;
  }

  // Step x/out past the M position (grid.y) and the batch slice (grid.z).
  const int m_rows = cg::this_grid().dim_blocks().y;
  const int batch = block.group_index().z;
  x += block.group_index().y * k + m_rows * k * batch;
  out += block.group_index().y * n + m_rows * n * batch;

  // Step w/scales/biases to the batch slice. For sub-byte Q the pointer moves
  // one unit per byte, so a row of k elements spans k * bits / 8 steps.
  constexpr int bits = cute::sizeof_bits_v<Q>;
  const int row_steps = k * cuda::std::min(8, bits) / 8;
  const int groups_per_row = k / group_size;
  const int w_layer = broadcast_w ? 0 : batch;
  w += static_cast<int64_t>(n) * w_layer * row_steps;
  scales += static_cast<int64_t>(n) * w_layer * groups_per_row;
  if constexpr (has_bias) {
    biases += static_cast<int64_t>(n) * w_layer * groups_per_row;
  }

  // Shared per-row compute (dequantize + FMA + reduce + store).
  qmv_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
      x, w, scales, biases, out, row, n, k);
}

// Gather variant of qmv_kernel: each grid.z slot b looks up lhs_indices[b] /
// rhs_indices[b] to choose which activation and which quantized weight matrix
// it multiplies, and writes its result into the b-th output slice. Weights
// arrive as packed uint32 words and are reinterpreted as the sub-byte element
// type Q before the shared compute path runs.
template <
    int rows_per_block,
    int elems_per_thread,
    int group_size,
    bool has_bias,
    bool has_residue_k,
    typename T,
    typename Q,
    typename S>
__global__ void gather_qmv_kernel(
    const T* x,
    const uint32_t* w,
    const S* scales,
    const T* biases,
    T* out,
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    int n,
    int k,
    int64_t x_batch_stride,
    int64_t w_batch_stride,
    int64_t s_batch_stride,
    int64_t b_batch_stride,
    int output_stride) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  // Output row owned by this warp; warps past the end of the matrix bail out.
  const int row =
      block.group_index().x * rows_per_block + warp.meta_group_rank();
  if (row >= n) {
    return;
  }

  // Indirection: which activation / weight batch this output slot reads from.
  const uint32_t batch = block.group_index().z;
  const uint32_t x_sel = lhs_indices[batch];
  const uint32_t w_sel = rhs_indices[batch];

  // Apply the gathered batch offsets.
  x += x_sel * x_batch_stride;
  w += w_sel * w_batch_stride;
  scales += w_sel * s_batch_stride;
  if constexpr (has_bias) {
    biases += w_sel * b_batch_stride;
  }
  out += batch * output_stride;

  // Step x/out past the M position (grid.y).
  x += block.group_index().y * k;
  out += block.group_index().y * n;

  // Reinterpret the packed words as Q* and run the shared per-row compute.
  qmv_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
      x, reinterpret_cast<const Q*>(w), scales, biases, out, row, n, k);
}

template <
int group_size,
bool has_bias,
Expand Down Expand Up @@ -433,4 +525,106 @@ void qmv(
});
}

// Host launcher for gather_qmv_kernel. Dispatches on the element type of
// `out` and the quantization parameters, then records one kernel node on
// `encoder` with grid layout (ceil(n / rows_per_block), m, B), one warp per
// output row.
void gather_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder) {
  // Tag used in dispatch error messages.
  const char* tag = "[gather_qmm]";
  // Problem shape: out is (..., m, n); x rows have length k; B is the number
  // of batched output matrices.
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int B = out.size() / (m * n);

  // Batch strides for contiguous inputs.
  // NOTE(review): this reads strides()[0] as the batch stride, which assumes
  // each input is contiguous with a leading batch axis — confirm callers
  // guarantee this (GatherQMM::eval_gpu does not check it here).
  int64_t x_batch_stride = x.strides()[0];
  int64_t w_batch_stride = w.strides()[0];
  int64_t s_batch_stride = scales.strides()[0];
  int64_t b_batch_stride =
      biases ? biases->strides()[0] : static_cast<int64_t>(0);

  dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
    dispatch_quant_types<T>(
        bits,
        group_size,
        mode,
        tag,
        // NOTE: the lambda's `group_size` template parameter deliberately
        // shadows the runtime `group_size` argument, turning it into a
        // compile-time constant inside this scope.
        [&]<typename Q, typename S, int group_size>() {
          // Register all arrays with the encoder for dependency tracking.
          encoder.set_input_array(x);
          encoder.set_input_array(w);
          encoder.set_input_array(scales);
          if (biases) {
            encoder.set_input_array(*biases);
          }
          encoder.set_input_array(lhs_indices);
          encoder.set_input_array(rhs_indices);
          encoder.set_output_array(out);

          // Whether this quantized format carries a separate bias array is
          // derived from the Q type's traits.
          constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
          constexpr int rows_per_block = 8;
          // Wider per-thread loads for narrow (<= 4-bit) quants with 16-bit
          // or smaller activations.
          constexpr int elems_per_thread =
              (cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16
                                                                            : 8;

          // grid.x: row tiles, grid.y: M, grid.z: batch (one warp per row).
          dim3 num_blocks{
              uint32_t(cuda::ceil_div(n, rows_per_block)),
              uint32_t(m),
              uint32_t(B)};
          dim3 block_dims{WARP_SIZE, rows_per_block};
          int output_stride = m * n;

          auto x_ptr = gpu_ptr<T>(x);
          auto w_ptr = gpu_ptr<uint32_t>(w);
          auto s_ptr = gpu_ptr<S>(scales);
          auto b_ptr = biases ? gpu_ptr<T>(*biases) : (const T*)nullptr;
          auto o_ptr = gpu_ptr<T>(out);
          auto li_ptr = gpu_ptr<uint32_t>(lhs_indices);
          auto ri_ptr = gpu_ptr<uint32_t>(rhs_indices);

          // Argument order must match gather_qmv_kernel's parameter list
          // exactly — do not reorder.
          void* args[] = {
              &x_ptr,
              &w_ptr,
              &s_ptr,
              &b_ptr,
              &o_ptr,
              &li_ptr,
              &ri_ptr,
              &n,
              &k,
              &x_batch_stride,
              &w_batch_stride,
              &s_batch_stride,
              &b_batch_stride,
              &output_stride};

          // has_residue_k selects the tail-handling kernel variant when k is
          // not a multiple of one warp-iteration's element count.
          dispatch_bool(
              k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
                auto* kernel = &cu::gather_qmv_kernel<
                    rows_per_block,
                    elems_per_thread,
                    group_size,
                    has_bias,
                    has_residue_k.value,
                    T,
                    Q,
                    S>;
                encoder.add_kernel_node_raw(
                    reinterpret_cast<void*>(kernel),
                    num_blocks,
                    block_dims,
                    {},
                    0,
                    args);
              });
        });
  });
}

} // namespace mlx::core
72 changes: 72 additions & 0 deletions mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,78 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
quantization_mode_to_string(mode_)));
}

// GPU evaluation for GatherQMM: batched quantized matmul where the per-batch
// activation and weight matrices are selected through index arrays.
// inputs layout: x, w, scales, [biases], lhs_indices, rhs_indices — the bias
// array is present only when inputs.size() == 6.
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("GatherQMM::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  const array& x = inputs[0];
  const array& w = inputs[1];
  const array& scales = inputs[2];
  std::optional<array> biases;
  if (inputs.size() == 6) {
    biases = inputs[3];
  }
  // The index arrays always occupy the last two input slots.
  const array& lhs_indices = inputs[inputs.size() - 2];
  const array& rhs_indices = inputs[inputs.size() - 1];

  // Problem shape (also reported in the error message below).
  const int M = out.shape(-2);
  const int N = out.shape(-1);
  const int K = x.shape(-1);
  const int B = out.size() / (M * N);

  // Check whether either qmv predicate accepts this problem configuration.
  auto accepted_by = [&](auto&& predicate) {
    return predicate(
        x,
        w,
        scales,
        biases,
        out,
        transpose_,
        bits_,
        group_size_,
        mode_,
        encoder.device());
  };

  if (accepted_by(supports_qmv) || accepted_by(supports_fp_qmv)) {
    out.set_data(cu::malloc_async(out.nbytes(), encoder));
    gather_qmv(
        x,
        w,
        scales,
        biases,
        lhs_indices,
        rhs_indices,
        out,
        bits_,
        group_size_,
        mode_,
        encoder);
    return;
  }

  // No kernel covers this configuration.
  throw std::runtime_error(
      fmt::format(
          "[gather_qmm] No implementation for "
          "problem shape: {}x{}x{}x{}, transpose: {}, "
          "activation: {}, bits: {}, group size: {}, mode: \"{}\".",
          M,
          N,
          K,
          B,
          transpose_,
          dtype_to_string(x.dtype()),
          bits_,
          group_size_,
          quantization_mode_to_string(mode_)));
}

void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down