Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2626,12 +2626,13 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id;

auto unary_op = (ggml_unary_op)dst->op_params[0];
float limit = *(const float *)(dst->op_params + 1);
if (src0_2) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
dst->src[4], dst->src[5],
(const char *)src0_1->data, src0_2 ? (const char *)src0_2->data : nullptr,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream);
(float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
auto local_src0_1 = *src0_1;
local_src0_1.ne[1] /= 2;
Expand All @@ -2642,7 +2643,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
nullptr, nullptr,
(const char *)local_src0_1.data, (const char *)local_src0_2.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream);
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
GGML_ASSERT(!dst->src[5]);
auto local_bias_1 = *dst->src[4];
Expand All @@ -2653,7 +2654,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
&local_bias_1, &local_bias_2,
(const char *)local_src0_1.data, (const char *)local_src0_2.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream);
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
}
}
CUDA_CHECK(cudaGetLastError());
Expand Down Expand Up @@ -2773,9 +2774,11 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
(float *)dst->data, ggml_nelements(dst), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
1.702f, 7.0f, stream);
} else {
float limit = *((const float *)(dst->op_params + 1));
//printf("%s: using limit = %g\n", __func__, limit);
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst->data);
(float *)dst->data, limit);
}
} else {

Expand All @@ -2801,8 +2804,10 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
(float *)dst->data, ggml_nelements(dst), dst->ne[0], src0_1->ne[1], src0_1->ne[1],
1.702f, 7.0f, stream);
} else {
float limit = *((const float *)(dst->op_params + 1));
//printf("%s: using limit = %g\n", __func__, limit);
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), dst->ne[0],
(const float *)dst_up_gate_contiguous.get(), (float *)dst->data);
(const float *)dst_up_gate_contiguous.get(), (float *)dst->data, limit);
}
}
CUDA_CHECK(cudaGetLastError());
Expand Down Expand Up @@ -2970,6 +2975,8 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
}

auto unary_op = (ggml_unary_op)dst->op_params[0];
float limit = *(const float *)(dst->op_params + 1);
//printf("%s: using limit = %g\n", __func__, limit);
if (src0_2) {
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
Expand All @@ -2993,7 +3000,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
(float *)dst_gate_contiguous.get(), limit);
}
} else {
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
Expand All @@ -3002,7 +3009,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0],
(const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
(const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get(), limit);
}
dst_row.data = dst_gate_contiguous.get();
dst_row.ne[0] /= 2;
Expand Down Expand Up @@ -3065,6 +3072,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor

auto stream = ctx.stream();

float limit = *(const float *)(dst->op_params + 1);

auto ne10_padded = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
auto nb10_padded = ne10_padded*sizeof(block_q8_1)/QK8_1;
auto quantized_size = nb10_padded*src1->ne[1];
Expand All @@ -3083,7 +3092,7 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
dst->src[4], dst->src[5],
(const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(),
(float *)dst->data, 0, src0_1->ne[1], 1, ne10_padded,
(ggml_unary_op)dst->op_params[0], stream);
(ggml_unary_op)dst->op_params[0], limit, stream);
return;
}

Expand Down Expand Up @@ -3116,8 +3125,9 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}

//printf("%s: using limit = %g\n", __func__, limit);
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst->data, dst_up.get(), (float *)dst->data);
(const float *)dst->data, dst_up.get(), (float *)dst->data, limit);
CUDA_CHECK(cudaGetLastError());

}
Expand Down
29 changes: 17 additions & 12 deletions ggml/src/ggml-cuda/iqk_mmvq_templates.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static __device__ void iqk_fused_mul_mat_vec_q_kernel(
const void * __restrict__ vup, const void * __restrict__ vgate, const void * __restrict__ vy, float * __restrict__ dst,
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
ggml_unary_op unary_op) {
ggml_unary_op unary_op, float limit) {

constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
Expand Down Expand Up @@ -191,7 +191,12 @@ static __device__ void iqk_fused_mul_mat_vec_q_kernel(
float g = tmp_g[j][threadIdx.x];
float r;
switch (unary_op) {
case GGML_UNARY_OP_SILU: r = u*g/(1 + expf(-g)); break;
case GGML_UNARY_OP_SILU:
{
g = g/(1 + expf(-g));
g = min(g, limit);
r = max(-limit, min(limit, u))*g;
} break;
case GGML_UNARY_OP_RELU: r = fmaxf(g, 0.0f) * u; break;
case GGML_UNARY_OP_GELU: {
constexpr float GELU_COEF_A = 0.044715f;
Expand Down Expand Up @@ -243,7 +248,7 @@ static __global__ void iqk_fused_mul_mat_vec_q(
const void * __restrict__ vx_u, const void * __restrict__ vx_g, const void * __restrict__ vy, float * __restrict__ dst,
const char * __restrict__ ids_data, const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op, float limit) {

int i2 = blockIdx.y;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
Expand All @@ -256,7 +261,7 @@ static __global__ void iqk_fused_mul_mat_vec_q(
char * cdst = (char *)dst + i2*nb2;
iqk_fused_mul_mat_vec_q_kernel<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(
cx_u, cx_g, cy, (float *)cdst, cx_u_b, cx_g_b,
ncols_x, nrows_x, nrows_y, nrows_dst, row_size, unary_op);
ncols_x, nrows_x, nrows_y, nrows_dst, row_size, unary_op, limit);
}

template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
Expand Down Expand Up @@ -307,56 +312,56 @@ static void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream)
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 2:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 3:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 4:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 5:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 6:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 7:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
case 8:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op, args.limit);
break;
default:
GGML_ASSERT(false);
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/mmvq-args.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ struct mmvq_args {
const uint64_t ids_nb0;
const uint64_t bias_nb1;
ggml_unary_op unary_op;
float limit;
};

Loading