diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index df450b18788..15ab8541dda 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -76,11 +76,9 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
-
-#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_QY1     (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2     (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG      (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -203,9 +201,9 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 
 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 #define FP16_MMA_AVAILABLE
@@ -219,9 +217,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +231,8 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -242,14 +241,16 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+        GGML_CUDA_CC_IS_MTHREADS(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
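Note (reviewer sketch, not part of the patch): the QY1/QY2/NG macros fold Moore Threads architectures into ggml's single compute-capability integer, so runtime gates such as the new clause in fast_fp16_hardware_available() reduce to plain range tests. A minimal standalone illustration; the offset constants are assumptions copied from common.cuh and the helper names are hypothetical:

    #include <cstdio>

    // Illustrative re-implementation; the authoritative macros live in
    // ggml/src/ggml-cuda/common.cuh.
    constexpr int CC_OFFSET_MTHREADS = 0x0100000; // assumed to match common.cuh
    constexpr int CC_OFFSET_AMD      = 0x1000000; // assumed to match common.cuh
    constexpr int CC_QY1 = CC_OFFSET_MTHREADS + 0x210; // MTT S80, MTT S3000
    constexpr int CC_QY2 = CC_OFFSET_MTHREADS + 0x220; // MTT S4000

    static bool is_mthreads(int cc) {
        return cc >= CC_OFFSET_MTHREADS && cc < CC_OFFSET_AMD;
    }

    // Mirrors the MUSA clause added above: only QY2 and newer advertise fast
    // FP16 to external libraries such as muBLAS.
    static bool musa_fast_fp16(int cc) {
        return is_mthreads(cc) && cc >= CC_QY2;
    }

    int main() {
        printf("QY1 (S80/S3000): %d\n", musa_fast_fp16(CC_QY1)); // 0 -> fp32 path
        printf("QY2 (S4000):     %d\n", musa_fast_fp16(CC_QY2)); // 1 -> fp16 path
        return 0;
    }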
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index c5668adb152..f3b794c3644 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c442a649243..67673f083d0 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -40,6 +40,9 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#ifdef GGML_USE_MUSA
+#include "ggml-musa/mublas.cuh"
+#endif // GGML_USE_MUSA
 #include "ggml.h"
 
 #include <algorithm>
@@ -1198,9 +1201,12 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int cc = ggml_cuda_info().devices[id].cc;
 
+    const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1234,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (fast_fp16_hardware_available(cc) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
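Note (reviewer sketch): the fattn-wmma-f16.cu hunk above makes "wmma" an alias that resolves to mtmusa::wmma under MUSA and nvcuda::wmma otherwise, so the tensor-core kernels compile unchanged on both stacks. A minimal kernel written against such an alias, in CUDA spelling (assumes sm_70+ or the MUSA equivalent, launched with one warp):

    #include <cuda_fp16.h>
    #include <mma.h>
    #ifdef GGML_USE_MUSA
    namespace wmma = mtmusa::wmma; // per the patch; MUSA SDK namespace
    #else
    namespace wmma = nvcuda::wmma;
    #endif

    // One 16x16x16 FP16 multiply-accumulate into an FP32 tile.
    __global__ void f16_mma_16x16x16(const half * a, const half * b, float * c) {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> fa;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> fb;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> fc;
        wmma::fill_fragment(fc, 0.0f);
        wmma::load_matrix_sync(fa, a, 16); // leading dimension 16
        wmma::load_matrix_sync(fb, b, 16);
        wmma::mma_sync(fc, fa, fb, fc);
        wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
    }
    // launch: f16_mma_16x16x16<<<1, 32>>>(a, b, c);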
@@ -1742,6 +1748,52 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
+#ifndef GGML_USE_MUSA
+static void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+) {
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // use cublasGemmBatchedEx
+    const int64_t ne23 = ne12*ne13;
+
+    ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+    ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+            src0_f16, src1_f16, dst_t,
+            ptrs_src.get(), ptrs_dst.get(),
+            ne12, ne13,
+            ne23,
+            nb02, nb03,
+            src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+            src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+            nbd2, nbd3,
+            r2, r3);
+    CUDA_CHECK(cudaGetLastError());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
+                   (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+            ne23,
+            cu_compute_type,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+}
+#endif // GGML_USE_MUSA
+
 static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -1869,34 +1921,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                     cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
-        // use cublasGemmBatchedEx
-        const int64_t ne23 = ne12*ne13;
-
-        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
-
-        dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
-        CUDA_CHECK(cudaGetLastError());
-
-        CUBLAS_CHECK(
-        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
-                ne23,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+            ctx,
+            src0, src1, dst,
+            src0_f16, src1_f16, dst_t,
+            nbd2, nbd3,
+            r2, r3,
+            s11, s12, s13,
+            alpha, beta,
+            cu_data_type, cu_compute_type,
+            main_stream);
     }
 #endif
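Note (reviewer sketch): both the CUDA copy above and the MUSA copy in mublas.cu depend on k_compute_batched_ptrs to emit one pointer per (i12, i13) batch pair, with integer division by the broadcast ratios r2 = ne12/ne02 and r3 = ne13/ne03 mapping several src1 batches onto one src0 batch (the grouped-query-attention case). A host-side rendering of that mapping with illustrative shapes:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne02 = 2, ne12 = 6;   // illustrative batch dims
        const int64_t r2   = ne12 / ne02;   // 3 src1 batches share one src0 batch

        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            const int64_t i02 = i12 / r2;   // same formula as the kernel
            printf("src1 batch %lld -> src0 batch %lld\n",
                   (long long) i12, (long long) i02);
        }
        return 0;
    }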
@@ -3009,9 +3043,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                         return false;
                     }
 #ifdef GGML_USE_MUSA
-                    if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
-                        !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                        return false;
+                    const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                    if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                        if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                            a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                            return false;
+                        }
+                        if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                            a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                            return false;
+                        }
                     }
 #endif // GGML_USE_MUSA
                     switch (a->type) {
@@ -3038,11 +3079,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                         case GGML_TYPE_IQ4_NL:
                         case GGML_TYPE_IQ4_XS:
                         case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                            if (a->type == GGML_TYPE_Q3_K) {
-                                return false;
-                            }
-#endif // GGML_USE_MUSA
                             return true;
                         default:
                             return false;
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index 971314debc7..882006b2b61 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -27,7 +27,8 @@ if (MUSAToolkit_FOUND)
 
     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
-    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
+    file(GLOB   HRDS "../ggml-musa/*.cuh")
+    list(APPEND GGML_HEADERS_MUSA ${HRDS})
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
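Note (reviewer sketch): the supports_op change above narrows the old blanket MUSA opt-out to two (architecture, op, type) denylist entries, and everything else now falls through to the generic checks (which is why the Q3_K special case below it could be deleted). The shape of that pattern, with stand-in enums rather than the real ggml types:

    #include <cstdio>

    enum class Op   { MUL_MAT, MUL_MAT_ID };
    enum class Type { F16, F32, Q2_K };

    struct DenyEntry {
        bool on_qy1; // true -> applies to QY1, false -> applies to QY2
        Op   op;
        Type a, b;
    };

    // One row per known-bad combination, mirroring the two ifs in the hunk.
    static const DenyEntry k_denylist[] = {
        { true,  Op::MUL_MAT,    Type::F16,  Type::F16 },
        { false, Op::MUL_MAT_ID, Type::Q2_K, Type::F32 },
    };

    static bool supported(bool is_qy1, Op op, Type a, Type b) {
        for (const DenyEntry & e : k_denylist) {
            if (e.on_qy1 == is_qy1 && e.op == op && e.a == a && e.b == b) {
                return false;
            }
        }
        return true; // fall through to the generic checks
    }

    int main() {
        printf("QY1 f16 MUL_MAT:     %d\n",
               supported(true,  Op::MUL_MAT,    Type::F16,  Type::F16)); // 0
        printf("QY2 q2_k MUL_MAT_ID: %d\n",
               supported(false, Op::MUL_MAT_ID, Type::Q2_K, Type::F32)); // 0
        return 0;
    }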
diff --git a/ggml/src/ggml-musa/mublas.cu b/ggml/src/ggml-musa/mublas.cu
new file mode 100644
index 00000000000..a2cef83bc53
--- /dev/null
+++ b/ggml/src/ggml-musa/mublas.cu
@@ -0,0 +1,79 @@
+#include "mublas.cuh"
+
+static __global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void ** ptrs_src, void ** ptrs_dst,
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t nb02, size_t nb03,
+        size_t nb12, size_t nb13,
+        size_t nbd2, size_t nbd3,
+        int64_t r2, int64_t r3) {
+    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+}
+
+void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+) {
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // use cublasGemmBatchedEx
+    const int64_t ne23 = ne12*ne13;
+
+    // Allocate memory for pointer arrays using cudaMalloc to avoid segmentation faults in muBLAS.
+    const void ** ptrs_src;
+    void ** ptrs_dst;
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+            src0_f16, src1_f16, dst_t,
+            ptrs_src, ptrs_dst,
+            ne12, ne13,
+            ne23,
+            nb02, nb03,
+            src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+            src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+            nbd2, nbd3,
+            r2, r3);
+    CUDA_CHECK(cudaGetLastError());
+
+    // This operation is essential for musa; without it, generated tokens will
+    // be garbled and may eventually cause MUBLAS_STATUS_INTERNAL_ERROR.
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/nb00,
+                   (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, s11,
+            beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne0,
+            ne23,
+            cu_compute_type,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    CUDA_CHECK(cudaFree(ptrs_src));
+    CUDA_CHECK(cudaFree(ptrs_dst));
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-musa/mublas.cuh b/ggml/src/ggml-musa/mublas.cuh
new file mode 100644
index 00000000000..53641b07e8a
--- /dev/null
+++ b/ggml/src/ggml-musa/mublas.cuh
@@ -0,0 +1,14 @@
+#include "ggml-cuda/common.cuh"
+
+void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+);
\ No newline at end of file
diff --git a/ggml/src/ggml-musa/mudnn.cuh b/ggml/src/ggml-musa/mudnn.cuh
index a63be5755c7..c30128561e8 100644
--- a/ggml/src/ggml-musa/mudnn.cuh
+++ b/ggml/src/ggml-musa/mudnn.cuh
@@ -1,7 +1,7 @@
 #pragma once
 
-#include "../include/ggml.h"
-#include "../ggml-cuda/common.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 
 // Asynchronously copies data from src tensor to dst tensor using the provided context.
 // Returns a musaError_t indicating success or failure.
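Note (reviewer sketch): mublas.cu differs from the CUDA copy in two ways, both flagged in its comments: the pointer arrays come from plain cudaMalloc/cudaFree instead of the ggml pool, and a cudaDeviceSynchronize() separates the pointer-filling kernel from the batched GEMM. A stripped-down standalone illustration of the pattern (plain CUDA, hypothetical sizes; the GEMM call itself is elided):

    #include <cstdio>
    #include <cuda_runtime.h>

    #define CHECK(call) do { cudaError_t e = (call); if (e != cudaSuccess) { \
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e)); return 1; } } while (0)

    // Batched GEMM APIs consume device-resident arrays of matrix pointers,
    // one per batch entry, so a tiny kernel fills them on the GPU.
    __global__ void fill_ptrs(const float * base, const float ** ptrs, size_t stride, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            ptrs[i] = base + (size_t) i*stride;
        }
    }

    int main() {
        const int    n      = 8;  // batch count (hypothetical)
        const size_t stride = 16; // elements between consecutive matrices (hypothetical)

        float        * data = nullptr;
        const float ** ptrs = nullptr;
        CHECK(cudaMalloc(&data, n*stride*sizeof(float)));
        CHECK(cudaMalloc(&ptrs, n*sizeof(float *)));

        fill_ptrs<<<1, n>>>(data, ptrs, stride, n);
        CHECK(cudaGetLastError());

        // The patch reports that muBLAS needs the pointer array fully written
        // before cublasGemmBatchedEx runs; the explicit synchronization makes
        // that ordering unconditional.
        CHECK(cudaDeviceSynchronize());

        // ... cublasGemmBatchedEx(handle, ..., (const void **) ptrs, ...) ...

        CHECK(cudaFree(ptrs));
        CHECK(cudaFree(data));
        return 0;
    }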