-
Notifications
You must be signed in to change notification settings - Fork 15.9k
feat(cuda): Add highly optimized CUDA kernel for HardSwish activation #17943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,12 @@ | |
| #include <string> | ||
| #include <vector> | ||
|
|
||
|
|
||
|
|
||
|
|
||
| void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Look at how things are done in this repo: we don't define prototypes here, but in an op-specific file. |
||
|
|
||
|
|
||
| static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); | ||
|
|
||
| [[noreturn]] | ||
|
|
@@ -2503,15 +2509,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg | |
| case GGML_UNARY_OP_RELU: | ||
| ggml_cuda_op_relu(ctx, dst); | ||
| break; | ||
| case GGML_UNARY_OP_HARDSWISH: | ||
| ggml_cuda_op_hardswish(ctx, dst); | ||
| break; | ||
| case GGML_UNARY_OP_SIGMOID: | ||
| ggml_cuda_op_sigmoid(ctx, dst); | ||
| break; | ||
| case GGML_UNARY_OP_HARDSIGMOID: | ||
| ggml_cuda_op_hardsigmoid(ctx, dst); | ||
| break; | ||
| case GGML_UNARY_OP_HARDSWISH: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't move this block, this is unnecessary noise. |
||
| ggml_cuda_op_hardswish(ctx, dst); | ||
| break; | ||
| case GGML_UNARY_OP_EXP: | ||
| ggml_cuda_op_exp(ctx, dst); | ||
| break; | ||
|
|
@@ -4296,6 +4302,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |
| switch (op->op) { | ||
| case GGML_OP_UNARY: | ||
| switch (ggml_get_unary_op(op)) { | ||
| case GGML_UNARY_OP_HARDSWISH: | ||
| return true; | ||
| case GGML_UNARY_OP_ABS: | ||
| case GGML_UNARY_OP_SGN: | ||
| case GGML_UNARY_OP_NEG: | ||
|
|
@@ -4305,7 +4313,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |
| case GGML_UNARY_OP_RELU: | ||
| case GGML_UNARY_OP_SIGMOID: | ||
| case GGML_UNARY_OP_HARDSIGMOID: | ||
| case GGML_UNARY_OP_HARDSWISH: | ||
| case GGML_UNARY_OP_GELU_ERF: | ||
| case GGML_UNARY_OP_GELU_QUICK: | ||
| case GGML_UNARY_OP_TANH: | ||
|
|
@@ -4554,7 +4561,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |
| case GGML_OP_COS: | ||
| case GGML_OP_CLAMP: | ||
| case GGML_OP_LOG: | ||
| return true; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you intend to remove this `return true;`? |
||
| case GGML_OP_SSM_SCAN: { | ||
| if (op->src[3]->ne[0] == 1) { | ||
| // Mamba2 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ static __device__ __forceinline__ float op_gelu(float x) { | |
| return ggml_cuda_op_gelu_single(x); | ||
| } | ||
|
|
||
|
|
||
| static __device__ __forceinline__ float op_gelu_erf(float x) { | ||
| const float SQRT_2_INV = 0.70710678118654752440084436210484f; | ||
|
|
||
|
|
@@ -194,9 +195,33 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst | |
| ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst); | ||
| } | ||
|
|
||
|
|
||
| // --- Custom HardSwish Implementation by Chandan --- | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why the little comment here? |
||
// Element-wise HardSwish: dst[i] = x * min(max(x + 3, 0), 6) / 6.
// Uses a grid-stride loop so any launch configuration covers all k elements,
// and 64-bit indexing so tensors with more than INT_MAX elements are handled
// without truncation (ggml_nelements returns int64_t).
static __global__ void k_hardswish(const float * src, float * dst, const int64_t k) {
    const int64_t stride = (int64_t) gridDim.x * blockDim.x;
    for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < k; i += stride) {
        const float x = src[i];
        dst[i] = x * fminf(fmaxf(x + 3.0f, 0.0f), 6.0f) / 6.0f;
    }
}
|
|
||
// Applies HardSwish to the F32 tensor dst->src[0], writing into dst, on the
// CUDA stream of the given backend context.
// Assumes src and dst are contiguous F32 tensors of equal element count —
// NOTE(review): confirm against the checks done for the other unary ops.
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src = dst->src[0];

    const float * src_d = (const float *) src->data;
    float       * dst_d = (float       *) dst->data;

    const int64_t num_elements = ggml_nelements(src);
    if (num_elements == 0) {
        return; // launching with grid size 0 is an invalid configuration
    }

    // ceil-div in 64-bit arithmetic: the previous int computation of
    // (num_elements + block_size - 1) overflowed for very large tensors.
    const int     block_size = 256;
    const int     grid_size  = (int) ((num_elements + block_size - 1) / block_size);

    cudaStream_t stream = ctx.stream();
    k_hardswish<<<grid_size, block_size, 0, stream>>>(src_d, dst_d, num_elements);
}
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unnecessary decoration. |
||
|
|
||
|
|
||
|
|
||
|
|
||
| void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| ggml_cuda_op_unary<op_exp>(ctx, dst); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you remove this? This is unrelated with your PR.
If you have a fatal error, just use the correct option.