Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t src1_padded_row_size, cudaStream_t stream) {

const int64_t ne00 = src0->ne[0];
const int64_t nb01 = src0->nb[1];

const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
Expand All @@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne0 = dst->ne[0];

const int64_t row_diff = row_high - row_low;
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);

int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
Expand All @@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};

switch (src0->type) {
case GGML_TYPE_Q4_0:
Expand Down Expand Up @@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_NL:
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_KS:
mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
break;
Expand Down Expand Up @@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_KS:
mmq_supported = true;
break;
default:
Expand Down
Loading