diff --git a/CMakeLists.txt b/CMakeLists.txt
index c231ec0e3f..86066791fa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -149,7 +149,6 @@ function (llama_option_depr TYPE OLD NEW)
     endif()
 endfunction()
 
-llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA)
 llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA)
 llama_option_depr(WARNING LLAMA_METAL GGML_METAL)
 llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8d17bc669a..245422e978 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -80,6 +80,12 @@
 #include
 #include
 
+
+
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
 [[noreturn]]
@@ -2503,15 +2509,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_HARDSWISH:
+            ggml_cuda_op_hardswish(ctx, dst);
+            break;
         case GGML_UNARY_OP_SIGMOID:
             ggml_cuda_op_sigmoid(ctx, dst);
             break;
         case GGML_UNARY_OP_HARDSIGMOID:
             ggml_cuda_op_hardsigmoid(ctx, dst);
             break;
-        case GGML_UNARY_OP_HARDSWISH:
-            ggml_cuda_op_hardswish(ctx, dst);
-            break;
         case GGML_UNARY_OP_EXP:
             ggml_cuda_op_exp(ctx, dst);
             break;
@@ -4296,6 +4302,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_HARDSWISH:
+                    return true;
                 case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:
@@ -4305,7 +4313,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
-                case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index d4866067a4..9d00ffe71d 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -21,6 +21,7 @@ static __device__ __forceinline__ float op_gelu(float x) {
     return ggml_cuda_op_gelu_single(x);
 }
 
+
 static __device__ __forceinline__ float op_gelu_erf(float x) {
     const float SQRT_2_INV = 0.70710678118654752440084436210484f;
 
@@ -194,9 +195,33 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_op_unary(ctx, dst);
 }
 
+
+// --- Custom HardSwish Implementation by Chandan ---
+static __global__ void k_hardswish(const float * src, float * dst, const int64_t k) {
+    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < k) {
+        float x = src[i];
+        // HardSwish: x * min(max(x + 3, 0), 6) / 6
+        dst[i] = x * fminf(fmaxf(x + 3.0f, 0.0f), 6.0f) / 6.0f;
+    }
+}
+
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary(ctx, dst);
+    const ggml_tensor * src = dst->src[0];
+    const float * src_d = (const float *) src->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t num_elements = ggml_nelements(src);
+    const int     block_size   = 256;
+    const int64_t grid_size    = (num_elements + block_size - 1) / block_size;
+
+    cudaStream_t stream = ctx.stream();
+
+    k_hardswish<<<grid_size, block_size, 0, stream>>>(src_d, dst_d, num_elements);
 }
+// --------------------------------------------------
+
+
 
 void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary(ctx, dst);
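
For reviewers who want to sanity-check the formula in isolation, below is a minimal standalone CUDA sketch (hypothetical and not part of this patch; the file name hs_check.cu and the kernel name k_hardswish_ref are made up for illustration). It applies the same HardSwish formula, x * min(max(x + 3, 0), 6) / 6, on the device and compares the result against a CPU reference at points covering all three regions of the function: x <= -3 (output 0), -3 < x < 3 (output x * (x + 3) / 6), and x >= 3 (output x).

// hs_check.cu (hypothetical standalone harness, not part of the patch)
// Build: nvcc -o hs_check hs_check.cu
#include <cstdio>
#include <cstdint>
#include <cmath>
#include <cuda_runtime.h>

// Same formula as k_hardswish above: x * min(max(x + 3, 0), 6) / 6.
static __global__ void k_hardswish_ref(const float * src, float * dst, const int64_t k) {
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i < k) {
        const float x = src[i];
        dst[i] = x * fminf(fmaxf(x + 3.0f, 0.0f), 6.0f) / 6.0f;
    }
}

int main() {
    // Probe all three regions: x <= -3 (zero), -3 < x < 3 (x*(x+3)/6), x >= 3 (identity).
    const int64_t n = 7;
    const float src[7] = { -4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 5.0f };
    float dst[7] = { 0 };

    float * src_d = nullptr;
    float * dst_d = nullptr;
    cudaMalloc(&src_d, n * sizeof(float));
    cudaMalloc(&dst_d, n * sizeof(float));
    cudaMemcpy(src_d, src, n * sizeof(float), cudaMemcpyHostToDevice);

    // Same launch-geometry computation as in ggml_cuda_op_hardswish.
    const int block_size = 256;
    const int grid_size  = (int) ((n + block_size - 1) / block_size);
    k_hardswish_ref<<<grid_size, block_size>>>(src_d, dst_d, n);

    cudaMemcpy(dst, dst_d, n * sizeof(float), cudaMemcpyDeviceToHost);

    for (int64_t i = 0; i < n; ++i) {
        // CPU reference of the same formula.
        const float ref = src[i] * fminf(fmaxf(src[i] + 3.0f, 0.0f), 6.0f) / 6.0f;
        printf("x = %5.1f  gpu = %8.5f  cpu = %8.5f  %s\n",
               src[i], dst[i], ref, fabsf(dst[i] - ref) < 1e-6f ? "ok" : "MISMATCH");
    }

    cudaFree(src_d);
    cudaFree(dst_d);
    return 0;
}

For an end-to-end check inside the repository, the existing test-backend-ops tool compares each backend's results against the CPU reference and should exercise GGML_UNARY_OP_HARDSWISH through this new code path as well.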