From d40fd4d0f5df2224659183089757be359e5838c9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 15:38:43 +0000 Subject: [PATCH 01/29] added files --- csrc/ops.h | 6 + csrc/quantization/awq/gemm_kernels.cu | 924 +++++++++++------- csrc/torch_bindings.cpp | 4 + .../layers/fused_moe/fused_moe.py | 58 ++ .../layers/fused_moe/fused_moe_awq.py | 95 ++ vllm/model_executor/layers/fused_moe/layer.py | 10 + .../model_executor/layers/quantization/awq.py | 80 ++ 7 files changed, 826 insertions(+), 351 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_awq.py diff --git a/csrc/ops.h b/csrc/ops.h index fb1099e4fe0c..1e1bfd6263a4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -67,6 +67,12 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters); +torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + torch::Tensor _topk_weights, torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, + bool mul_weights, int split_k_iters); + torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters, diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 6d6da5f3d874..f721af6b20b0 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,13 +1,15 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and -Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, -Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} } */ -#include + +#include #include #include "dequantize.cuh" @@ -18,20 +20,26 @@ namespace vllm { namespace awq { // Pack two half values. 
-static inline __device__ __host__ unsigned __pack_half2(const half x, - const half y) { - unsigned v0 = *((unsigned short*)&x); - unsigned v1 = *((unsigned short*)&y); +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) - gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, - half* __restrict__ A, int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, int M, int IC, - int OC, half* __restrict__ C) { +template +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + int M, + int IC, + int OC, + half* __restrict__ C) +{ // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -62,46 +70,43 @@ __global__ void __launch_bounds__(64) static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = - (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + - threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = - A + - (((int)blockIdx_y) / j_factors1 * 16 + - (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * - IC + - (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + - (((int)threadIdx.x) / (N / 8)) * (OC / 8) + - (((int)blockIdx_y) % j_factors1) * (N / 8) + - (((int)threadIdx.x) % (N / 8)) * 1; - // Why * 1 in the above line? - - half* A_shared_ptr = A_shared + - ((int)threadIdx.y) * row_stride_warp * (32 + 8) + - (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + - (((int)threadIdx.x) % (32 / 8)) * 8; - - half* B_shared_ptr = B_shared + - ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + - (((int)threadIdx.x) / (N / 8)) * (N + 8) + - (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + - ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors + - (((int)blockIdx_y) % j_factors1) * N + - (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = - C + - static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + - (((int)threadIdx.x) % 4) * 2; + half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; +// Why * 1 in the above line? 
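+// (The "* 1" is the per-thread stride along the packed OC dimension:
+//  each thread reads one int32, i.e. 8 packed int4 weights, per row it loads.)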
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -110,83 +115,57 @@ __global__ void __launch_bounds__(64) int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) { + if (ld_A_flag) + { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } else { + } + else + { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = - *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && - threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, - B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, - B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus - // zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * - // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) - // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * - // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * - // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * - // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = - *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / 
cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / - // 8)) * 8); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x - // % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = - // q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == - 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", - B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = - B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; } __syncthreads(); @@ -194,179 +173,120 @@ __global__ void __launch_bounds__(64) { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; 
cvt.u32.u64 %0, " - "addr; }\n" - : "=r"(addr) - : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + - (((((int)threadIdx.x) & 15) * 40) + - ((((int)threadIdx.x) >> 4) * 8))))); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned*)(A_shared_warp + 0))[0]), - "=r"(((unsigned*)(A_shared_warp + 0))[1]), - "=r"(((unsigned*)(A_shared_warp + 0))[2]), - "=r"(((unsigned*)(A_shared_warp + 0))[3]) - : "r"(addr)); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " - "addr; }\n" - : "=r"(addr) - : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + - (((int)threadIdx.y) * (N / 2))) + - (ax1_0 * 16))])) + - (((((int)threadIdx.x) & 15) * (N + 8)) + - ((((int)threadIdx.x) >> 4) * 8))))); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr)); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), 
"f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float 
*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } - #else +#else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " - "%13};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " - "%13};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), 
"r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } - #endif +#endif } } } - // TODO: Shang: Hoist loop invariance. +// TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + - ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + - local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) - dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, - int* __restrict__ zeros, half* __restrict__ C, int G) { +__global__ void __launch_bounds__(64) dequantize_weights( + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + half* __restrict__ C, + int G, + int in_c, + int out_c +) +{ + if (blockIdx.z > 0) { + B = B + blockIdx.z * in_c * out_c / 8; + scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; + zeros = zeros + blockIdx.z * in_c * out_c / G / 8; + C = C + blockIdx.z * in_c * out_c; + } int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -398,30 +318,14 @@ __global__ void __launch_bounds__(64) uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), 
"r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -430,57 +334,309 @@ __global__ void __launch_bounds__(64) } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx == 0) { - x_thread = qout_c; +template +__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + const int expert_num, + int pad_M, + int M, + int IC, + int OC, + half* __restrict__ C) +{ + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + int num_tokens = *num_tokens_post_padded; + int j_factors1 = ((OC + N - 1) / N); + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1); + int block = blockIdx_y / j_factors1; + if (block * 16 >= num_tokens) return; + + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (N + 8)]; + + __shared__ half scaling_factors_shared[N]; + __shared__ half zeros_shared[N]; + + half A_shared_warp[8]; + half B_shared_warp[N / 4]; + for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } } - if (thy == 0) { - y_thread = in_c; + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / N; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + + int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); + int token_id = sorted_token_ids_ptr[row]; + bool ld_A_flag = (token_id < num_valid_tokens); + half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8; + + int expert_id = expert_ids_ptr[block]; + B = B + OC * IC / 8 * expert_id; + scaling_factors = scaling_factors + OC * IC / G * expert_id; + zeros = zeros + OC * IC / G / 8 * expert_id; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
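+      // As written: w = (q - zero) * scale, i.e. one sub.f16x2 followed by
+      // one fma.rn.f16x2 for each of the four packed half2 lanes (x, y, z, w).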
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned 
*)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#else + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + +#endif + } + } } - if (thx == 0 && thy == 0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); + +// TODO: Shang: Hoist loop invariance. 
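+  // Write-back: each accumulated value is scattered to the row given by
+  // sorted_token_ids_ptr; padded rows (token_id >= num_valid_tokens) are
+  // skipped, and the result is optionally scaled by the token's top-k
+  // routing weight before the fp16 store.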
+ for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + int token_id = sorted_token_ids_ptr[row_offset]; + if (token_id < num_valid_tokens) + { + float value = C_warp[(ax1_0_1 * 8) + local_id]; + if (topk_weights) { + value = value * topk_weights[token_id]; + } + *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value); + } + } } +#endif +} + +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + int thx, + int thy) +{ + int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); + int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); + int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); + int out_c = qout_c * 8; + int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0) : _scaling_factors.size(1)); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx==0) { + x_thread = qout_c; + } + if (thy==0) { + y_thread = in_c; + } + if (thx==0 && thy==0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); + at::Tensor _de_kernel; + if (num_experts == 1) { + _de_kernel = torch::empty({in_c, out_c}, options); + } else { + _de_kernel = torch::empty({num_experts, in_c, out_c}, options); + } - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks, num_experts); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -489,61 +645,127 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters) { - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions() - 
.dtype(_in_feats.dtype()) - .device(_in_feats.device()); - at::Tensor _out_feats = - torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> - <<>>( +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + 
dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); + } + return _out_feats.sum(0); +} + +torch::Tensor awq_group_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int pad_num_in_feats = _sorted_token_ids_ptr.size(0); + int num_in_channels = _in_feats.size(2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + int num_experts = _topk_weights.size(1); + int top_k = num_experts / _in_feats.size(1); + int group_size = num_in_channels / _scaling_factors.size(1); + + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto topk_weights = mul_weights ? reinterpret_cast(_topk_weights.data_ptr()) : nullptr; + auto sorted_token_ids_ptr = reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); + auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); + auto num_tokens_post_padded = reinterpret_cast(_num_tokens_post_padded.data_ptr()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, out_feats); - } else if (num_out_channels % 64 == 0) { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * - split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> - <<>>( + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); + } + return _out_feats.sum(0); } diff --git a/csrc/torch_bindings.cpp 
b/csrc/torch_bindings.cpp index 18331a674eeb..55aeeef458d8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -117,6 +117,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("awq_gemm", &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); + // Quantized Grouped GEMM for AWQ. + ops.def("awq_fused_moe", &awq_fused_moe); + ops.def("awq_fused_moe", torch::kCUDA, &awq_fused_moe); + // Dequantization for AWQ. ops.def("awq_dequantize", &awq_dequantize); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3c62008fbfcc..fec2e378087c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -606,3 +606,61 @@ def fused_moe( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale) + +def fused_moe_awq( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_qzero: torch.Tensor, + w2_qzero: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + inplace: bool = False, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - w1_scale (torch.Tensor): scale to be used for w1. + - w2_scale (torch.Tensor): scale to be used for w2. + - w1_qzero (torch.Tensor): zero point to be used for w1. + - w2_qzero (torch.Tensor): zero point to be used for w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + # If large seq_len prefill, dequantize and use the fp16 MoE kernel. 
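+    # hidden_states.shape[:-1].numel() is the total number of tokens in the
+    # batch, so e.g. a prefill of 4 sequences x 512 tokens (2048 >= 1024)
+    # takes the dequantize-to-fp16 path below.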
+ do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 + if do_naive_dequant: + dequant_w1 = ops.awq_dequantize( + w1, w1_scale, w1_qzero, 0, 0,0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize( + w2, w2_scale, w2_qzero, 0, 0,0).permute(0, 2, 1) + + return fused_moe( + hidden_states=hidden_states, + w1=dequant_w1, + w2=dequant_w2, + gating_output=gating_output, + topk=topk, + renormalize=renormalize) + + else: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py new file mode 100644 index 000000000000..dc84ee79e149 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -0,0 +1,95 @@ +"""Fused MoE utilities for AWQ.""" +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger + +from .fused_moe import fused_moe, moe_align_block_size, fused_topk + +logger = init_logger(__name__) + +def fused_moe_awq( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_qzero: torch.Tensor, + w2_qzero: torch.Tensor, + pack_factor: int, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - pack_factor (int): Weight packing factor (int4 in int32 == 8) + - w1_scale (torch.Tensor): scale to be used for w1. + - w2_scale (torch.Tensor): scale to be used for w2. + - w1_qzero (torch.Tensor): zero point to be used for w1. + - w2_qzero (torch.Tensor): zero point to be used for w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + # If large seq_len prefill, dequantize and use the fp16 MoE kernel. 
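+    # Two regimes: with many tokens it is cheaper to dequantize the int4
+    # experts to fp16 once and reuse the existing fused_moe path; with few
+    # tokens the weights stay packed and the grouped AWQ GEMM kernels below
+    # are used directly.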
+ do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 + if do_naive_dequant: + dequant_w1 = ops.awq_dequantize( + w1, w1_scale, w1_qzero, 0, 0,0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize( + w2, w2_scale, w2_qzero, 0, 0,0).permute(0, 2, 1) + + return fused_moe(hidden_states, + dequant_w1, + dequant_w2, + gating_output, + topk, + renormalize) + + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size( + topk_ids, 16, w1.shape[0]) + + x = x.view(x.shape[0], 1, *x.shape[1:]) + + gate_up = ops.awq_group_gemm(x, + w1, + w1_scale, + w1_qzero, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + pack_factor) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=x.dtype, + device=x.device) + ops.silu_and_mul(out, gate_up) + + out = ops.awq_group_gemm(out, + w2, + w2_scale, + w2_qzero, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + pack_factor) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 73cfcd7fc85f..e45ac2b7f2d2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -161,6 +161,16 @@ def weight_loader(self, param: torch.nn.Parameter, else: tp_rank = get_tensor_model_parallel_rank() shard_size = self.intermediate_size_per_partition + + # If packed parameter (e.g. AWQ) and packing is on the + # same dimension as TP sharding, adjust indexing by + # pack factor (8 if int4 packed into int32 parameter). + sharded_dim = 0 if (shard_id == 0 or shard_id == 2) else 1 + packed_dim = getattr(param, "packed_dim", None) + if packed_dim is not None and packed_dim == sharded_dim: + assert hasattr(param, "pack_factor") + shard_size = shard_size // param.pack_factor + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) # w1, gate_proj case: Load into first shard of w13. 
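The weight_loader change above divides the tensor-parallel shard size by the pack factor whenever the packed dimension coincides with the sharded dimension. A minimal standalone sketch of that indexing arithmetic, using hypothetical sizes rather than vLLM's actual loader:

import torch

# Hypothetical sizes: int4 weights packed 8-per-int32 along the output dim.
intermediate_size_per_partition = 1024   # logical (unpacked) width per TP rank
pack_factor = 8
tp_rank = 1
num_ranks = 4

# Packed checkpoint tensor: the TP-sharded dim is also the packed dim, so it
# holds intermediate_size // pack_factor int32 columns per rank.
full_packed = torch.zeros(
    4096, num_ranks * intermediate_size_per_partition // pack_factor,
    dtype=torch.int32)

# Without the adjustment the slice would index 1024-wide chunks and run past
# the packed tensor; dividing by pack_factor keeps the slice in int32 units.
shard_size = intermediate_size_per_partition // pack_factor   # 128
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
loaded = full_packed[:, shard]
assert loaded.shape[1] == shard_size
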
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index a3854f70bb4f..324db18360bb 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,6 +4,7 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -174,3 +175,82 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + +class AWQMoEMethod(FusedMoEMethodBase): + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # WEIGHTS + w13_weight = Parameter( + torch.empty(num_experts, + hidden_size, + 2*intermediate_size // self.quant_config.pack_factor, + dtype=torch.int32), + required_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs}) + + w2_weight = Parameter( + torch.empty(num_experts, + intermediate_size, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs}) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter( + torch.empty(num_experts, + hidden_size // self.quant_config.group_size, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter( + torch.empty(num_experts, + intermediate_size // self.quant_config.group_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scales) + set_weight_attrs(w2_scales, **extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
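+        # One int32 row per quantization group (hidden_size // group_size
+        # rows); the last dim packs the int4 zero points for the concatenated
+        # w1/w3 outputs, pack_factor values per int32.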
+ w13_qzeros = Parameter( + torch.empty(num_experts, + hidden_size // self.quant_config.group_size, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs}) + + w2_qzeros = Parameter( + torch.empty(num_experts, + intermediate_size // self.quant_config.group_size, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs}) From f1d58369977a6d699a4df09d3f8284d99bcbf7ca Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 16:31:45 +0000 Subject: [PATCH 02/29] format --- csrc/ops.h | 10 +- csrc/quantization/awq/gemm_kernels.cu | 1069 ++++++++++------- .../layers/fused_moe/fused_moe.py | 58 - .../layers/fused_moe/fused_moe_awq.py | 59 +- vllm/model_executor/layers/fused_moe/layer.py | 6 +- .../model_executor/layers/quantization/awq.py | 128 +- 6 files changed, 726 insertions(+), 604 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 1e1bfd6263a4..46db627da776 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -68,10 +68,12 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, int64_t split_k_iters); torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - torch::Tensor _topk_weights, torch::Tensor _sorted_token_ids_ptr, - torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, - bool mul_weights, int split_k_iters); + torch::Tensor _scaling_factors, + torch::Tensor _zeros, torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, int split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index f721af6b20b0..4e8ecf4c72bc 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,14 +1,12 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - #include #include @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. 
-static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared 
memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,119 +194,185 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ 
__volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + 
(j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, 
%3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + 
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G, - int in_c, - int out_c -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G, + int in_c, int out_c) { if (blockIdx.z > 0) { - B = B + blockIdx.z * in_c * out_c / 8; - scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; - zeros = zeros + blockIdx.z * in_c * out_c / G / 8; - C = C + blockIdx.z * in_c * out_c; + B = B + blockIdx.z * in_c * out_c / 8; + scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; + zeros = zeros + blockIdx.z * in_c * out_c / G / 8; + C = C + blockIdx.z * in_c * out_c; } int j_factors1 = 4; int row_stride2 = 4; @@ -318,14 +405,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : 
"r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -334,27 +437,16 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -template +template __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - const float* __restrict__ topk_weights, - const int* __restrict__ sorted_token_ids_ptr, - const int* __restrict__ expert_ids_ptr, - const int* __restrict__ num_tokens_post_padded, - const int num_valid_tokens, - const int top_k, - const int expert_num, - int pad_M, - int M, - int IC, - int OC, - half* __restrict__ C) -{ + int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, int* __restrict__ zeros, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, + const int top_k, const int expert_num, int pad_M, int M, int IC, int OC, + half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -399,36 +491,34 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( scaling_factors = scaling_factors + OC * IC / G * expert_id; zeros = zeros + OC * IC / G / 8 * expert_id; - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; // Why * 1 in the above line? 
- half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC * + expert_num // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -437,38 +527,54 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -476,167 +582,239 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; 
cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 
8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), 
"r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. 
for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + int row_offset = + block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; int token_id = sorted_token_ids_ptr[row_offset]; - if (token_id < num_valid_tokens) - { + if (token_id < num_valid_tokens) { float value = C_warp[(ax1_0_1 * 8) + local_id]; if (topk_weights) { - value = value * topk_weights[token_id]; + value = value * topk_weights[token_id]; } - *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value); + *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(value); } } } #endif } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); - int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); - int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); - int out_c = qout_c * 8; - int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0) : _scaling_factors.size(1)); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy) { + int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); + int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); + int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); + int out_c = qout_c * 8; + int G = in_c / (_kernel.dim() == 2 ? 
_scaling_factors.size(0) + : _scaling_factors.size(1)); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel; - if (num_experts == 1) { - _de_kernel = torch::empty({in_c, out_c}, options); - } else { - _de_kernel = torch::empty({num_experts, in_c, out_c}, options); - } + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel; + if (num_experts == 1) { + _de_kernel = torch::empty({in_c, out_c}, options); + } else { + _de_kernel = torch::empty({num_experts, in_c, out_c}, options); + } - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks, num_experts); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks, num_experts); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -645,127 +823,134 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw 
std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); -} - -torch::Tensor awq_group_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - torch::Tensor _topk_weights, - torch::Tensor _sorted_token_ids_ptr, - torch::Tensor _expert_ids_ptr, - torch::Tensor _num_tokens_post_padded, - bool mul_weights, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int pad_num_in_feats = _sorted_token_ids_ptr.size(0); - int num_in_channels = _in_feats.size(2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - int num_experts = _topk_weights.size(1); - int top_k = num_experts / _in_feats.size(1); - int group_size = num_in_channels / _scaling_factors.size(1); - - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto topk_weights = mul_weights ? 
reinterpret_cast(_topk_weights.data_ptr()) : nullptr; - auto sorted_token_ids_ptr = reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); - auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); - auto num_tokens_post_padded = reinterpret_cast(_num_tokens_post_padded.data_ptr()); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<>>( +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, - _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<>>( + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, - 
_topk_weights.numel(), top_k, num_experts, pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); + } + return _out_feats.sum(0); +} + +torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int pad_num_in_feats = _sorted_token_ids_ptr.size(0); + int num_in_channels = _in_feats.size(2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + int num_experts = _topk_weights.size(1); + int top_k = num_experts / _in_feats.size(1); + int group_size = num_in_channels / _scaling_factors.size(1); + + at::Tensor _out_feats = torch::empty( + {split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, + options); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto topk_weights = mul_weights + ? reinterpret_cast(_topk_weights.data_ptr()) + : nullptr; + auto sorted_token_ids_ptr = + reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); + auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); + auto num_tokens_post_padded = + reinterpret_cast(_num_tokens_post_padded.data_ptr()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts, + pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, + out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts, + pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels, + out_feats); + } + return _out_feats.sum(0); } diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fec2e378087c..3c62008fbfcc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -606,61 +606,3 @@ def fused_moe( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale) - -def fused_moe_awq( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: 
torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_qzero: torch.Tensor, - w2_qzero: torch.Tensor, - a1_scale: torch.Tensor, - a2_scale: torch.Tensor, - inplace: bool = False, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - w1_scale (torch.Tensor): scale to be used for w1. - - w2_scale (torch.Tensor): scale to be used for w2. - - w1_qzero (torch.Tensor): zero point to be used for w1. - - w2_qzero (torch.Tensor): zero point to be used for w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - - # If large seq_len prefill, dequantize and use the fp16 MoE kernel. - do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 - if do_naive_dequant: - dequant_w1 = ops.awq_dequantize( - w1, w1_scale, w1_qzero, 0, 0,0).permute(0, 2, 1) - dequant_w2 = ops.awq_dequantize( - w2, w2_scale, w2_qzero, 0, 0,0).permute(0, 2, 1) - - return fused_moe( - hidden_states=hidden_states, - w1=dequant_w1, - w2=dequant_w2, - gating_output=gating_output, - topk=topk, - renormalize=renormalize) - - else: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index dc84ee79e149..2ffb1c4ce1de 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -4,10 +4,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from .fused_moe import fused_moe, moe_align_block_size, fused_topk +from .fused_moe import fused_moe, fused_topk, moe_align_block_size logger = init_logger(__name__) + def fused_moe_awq( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -46,50 +47,32 @@ def fused_moe_awq( # If large seq_len prefill, dequantize and use the fp16 MoE kernel. 
do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 if do_naive_dequant: - dequant_w1 = ops.awq_dequantize( - w1, w1_scale, w1_qzero, 0, 0,0).permute(0, 2, 1) - dequant_w2 = ops.awq_dequantize( - w2, w2_scale, w2_qzero, 0, 0,0).permute(0, 2, 1) - - return fused_moe(hidden_states, - dequant_w1, - dequant_w2, - gating_output, - topk, - renormalize) + dequant_w1 = ops.awq_dequantize(w1, w1_scale, w1_qzero, 0, 0, + 0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize(w2, w2_scale, w2_qzero, 0, 0, + 0).permute(0, 2, 1) + + return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output, + topk, renormalize) - topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) (sorted_token_ids, expert_ids, - num_tokens_post_padded) = moe_align_block_size( - topk_ids, 16, w1.shape[0]) + num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0]) - x = x.view(x.shape[0], 1, *x.shape[1:]) + x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) - gate_up = ops.awq_group_gemm(x, - w1, - w1_scale, - w1_qzero, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - pack_factor) + gate_up = ops.awq_group_gemm(x, w1, w1_scale, w1_qzero, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, False, pack_factor) out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), - dtype=x.dtype, - device=x.device) + dtype=hidden_states.dtype, + device=hidden_states.device) ops.silu_and_mul(out, gate_up) - out = ops.awq_group_gemm(out, - w2, - w2_scale, - w2_qzero, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - pack_factor) + out = ops.awq_group_gemm(out, w2, w2_scale, w2_qzero, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, True, pack_factor) return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e45ac2b7f2d2..a83505cee8e6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -161,9 +161,9 @@ def weight_loader(self, param: torch.nn.Parameter, else: tp_rank = get_tensor_model_parallel_rank() shard_size = self.intermediate_size_per_partition - - # If packed parameter (e.g. AWQ) and packing is on the - # same dimension as TP sharding, adjust indexing by + + # If packed parameter (e.g. AWQ) and packing is on the + # same dimension as TP sharding, adjust indexing by # pack factor (8 if int4 packed into int32 parameter). 
sharded_dim = 0 if (shard_id == 0 or shard_id == 2) else 1 packed_dim = getattr(param, "packed_dim", None) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 324db18360bb..80b3098d1b5a 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,7 +4,7 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -176,81 +176,91 @@ def apply(self, out.add_(bias) return out.reshape(out_shape) + class AWQMoEMethod(FusedMoEMethodBase): + def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - + # WEIGHTS - w13_weight = Parameter( - torch.empty(num_experts, - hidden_size, - 2*intermediate_size // self.quant_config.pack_factor, - dtype=torch.int32), - required_grad=False) + w13_weight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, { - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs}) - - w2_weight = Parameter( - torch.empty(num_experts, - intermediate_size, - hidden_size // self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + set_weight_attrs( + w13_weight, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs + }) + + w2_weight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, { - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs}) + set_weight_attrs( + w2_weight, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs + }) # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. 
- w13_scales = Parameter( - torch.empty(num_experts, - hidden_size // self.quant_config.group_size, - intermediate_size * 2, - dtype=params_dtype), - requires_grad=False) + w13_scales = Parameter(torch.empty(num_experts, + hidden_size // + self.quant_config.group_size, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = Parameter( - torch.empty(num_experts, - intermediate_size // self.quant_config.group_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scales = Parameter(torch.empty(num_experts, + intermediate_size // + self.quant_config.group_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) layer.register_parameter("w2_scale", w2_scales) set_weight_attrs(w2_scales, **extra_weight_attrs) # WEIGHT_ZERO_POINT # Allocate 2 zero points for w1 and w3 respectively. - w13_qzeros = Parameter( - torch.empty(num_experts, - hidden_size // self.quant_config.group_size, - 2 * intermediate_size // self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w13_qzeros = Parameter(torch.empty( + num_experts, + hidden_size // self.quant_config.group_size, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w13_qzeros", w13_qzeros) - set_weight_attrs(w13_qzeros, { - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs}) - - w2_qzeros = Parameter( - torch.empty(num_experts, - intermediate_size // self.quant_config.group_size, - hidden_size // self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + set_weight_attrs( + w13_qzeros, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs + }) + + w2_qzeros = Parameter(torch.empty( + num_experts, + intermediate_size // self.quant_config.group_size, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) - set_weight_attrs(w2_qzeros, { - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs}) + set_weight_attrs( + w2_qzeros, { + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **extra_weight_attrs + }) From 16baf11b4d66a5316b13abd54e67f414074ebea9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 17:50:14 +0000 Subject: [PATCH 03/29] stash --- csrc/quantization/awq/gemm_kernels.cu | 14 +++++++------- csrc/torch_bindings.cpp | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 4e8ecf4c72bc..513f760f0105 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -882,13 +882,13 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, return _out_feats.sum(0); } -torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, torch::Tensor _topk_weights, - torch::Tensor _sorted_token_ids_ptr, - torch::Tensor _expert_ids_ptr, - torch::Tensor _num_tokens_post_padded, - bool mul_weights, int split_k_iters) { +torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor 
_expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, int split_k_iters) { int num_in_feats = _in_feats.size(0); int pad_num_in_feats = _sorted_token_ids_ptr.size(0); int num_in_channels = _in_feats.size(2); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 55aeeef458d8..e24dc13706e6 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -119,7 +119,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized Grouped GEMM for AWQ. ops.def("awq_fused_moe", &awq_fused_moe); - ops.def("awq_fused_moe", torch::kCUDA, &awq_fused_moe); + ops.impl("awq_fused_moe", torch::kCUDA, &awq_fused_moe); // Dequantization for AWQ. ops.def("awq_dequantize", &awq_dequantize); From 03d9d8e7983c9fb3d78082a108fe7c24ede7c8a1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 18:15:56 +0000 Subject: [PATCH 04/29] torch library --- csrc/quantization/awq/gemm_kernels.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 513f760f0105..5619266e88c9 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include +#include #include #include "dequantize.cuh" @@ -762,8 +762,8 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy) { + torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, + int64_t thy) { int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); int num_experts = _kernel.dim() == 2 ? 
1 : _kernel.size(0); @@ -825,7 +825,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters) { + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); @@ -888,7 +888,7 @@ torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _sorted_token_ids_ptr, torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, - bool mul_weights, int split_k_iters) { + bool mul_weights, int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int pad_num_in_feats = _sorted_token_ids_ptr.size(0); int num_in_channels = _in_feats.size(2); From 54d6a8722a1a977536a2ff4e2ff6e4f50cd2f5cf Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 18:17:01 +0000 Subject: [PATCH 05/29] fixed another torch library --- csrc/ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ops.h b/csrc/ops.h index 46db627da776..387d36eb4222 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -73,7 +73,7 @@ torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _sorted_token_ids_ptr, torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, - bool mul_weights, int split_k_iters); + bool mul_weights, int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, From 524a94caab0e77da1d2a2b16f34ea9f147a4ee8a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 20:21:46 +0000 Subject: [PATCH 06/29] first end to end run with tp=1 --- csrc/ops.h | 3 +- vllm/_custom_ops.py | 15 ++++ .../layers/fused_moe/__init__.py | 2 + .../layers/fused_moe/fused_moe_awq.py | 37 ++++----- vllm/model_executor/layers/fused_moe/layer.py | 75 ++++++++++++------- .../model_executor/layers/quantization/awq.py | 62 ++++++++++----- .../model_executor/layers/quantization/fp8.py | 70 +++++++++-------- vllm/model_executor/model_loader/utils.py | 1 + vllm/model_executor/models/mixtral.py | 22 +----- 9 files changed, 172 insertions(+), 115 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 387d36eb4222..ca52b7aa7d5f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -69,7 +69,8 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, torch::Tensor _topk_weights, + torch::Tensor _zeros, + torch::Tensor _topk_weights, torch::Tensor _sorted_token_ids_ptr, torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 03308d04012a..ee47d8f577a1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -177,6 +177,21 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + +def awq_fused_moe(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: int, + mul_weights: bool, + pack_factor: int) -> torch.Tensor: + return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros, topk_weights, + sorted_token_ids, expert_ids, num_tokens_post_padded, + 
mul_weights, pack_factor) + # gptq def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index db837231c6ac..90208f4f6736 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,10 +1,12 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) +from vllm.model_executor.layers.fused_moe.fused_moe_awq import fused_moe_awq from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) __all__ = [ "fused_moe", + "fused_moe_awq", "fused_topk", "fused_experts", "get_config_file_name", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index 2ffb1c4ce1de..649b67a6b788 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -16,11 +16,11 @@ def fused_moe_awq( gating_output: torch.Tensor, topk: int, renormalize: bool, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_qzero: torch.Tensor, - w2_qzero: torch.Tensor, pack_factor: int, + w1_scales: torch.Tensor, + w2_scales: torch.Tensor, + w1_qzeros: torch.Tensor, + w2_qzeros: torch.Tensor, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -35,10 +35,10 @@ def fused_moe_awq( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - pack_factor (int): Weight packing factor (int4 in int32 == 8) - - w1_scale (torch.Tensor): scale to be used for w1. - - w2_scale (torch.Tensor): scale to be used for w2. - - w1_qzero (torch.Tensor): zero point to be used for w1. - - w2_qzero (torch.Tensor): zero point to be used for w2. + - w1_scales (torch.Tensor): scale to be used for w1. + - w2_scales (torch.Tensor): scale to be used for w2. + - w1_qzeros (torch.Tensor): zero point to be used for w1. + - w2_qzeros (torch.Tensor): zero point to be used for w2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -47,10 +47,11 @@ def fused_moe_awq( # If large seq_len prefill, dequantize and use the fp16 MoE kernel. do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 if do_naive_dequant: - dequant_w1 = ops.awq_dequantize(w1, w1_scale, w1_qzero, 0, 0, - 0).permute(0, 2, 1) - dequant_w2 = ops.awq_dequantize(w2, w2_scale, w2_qzero, 0, 0, - 0).permute(0, 2, 1) + # TODO: why is this not contiguous alreayd? 
+ dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, + 0).permute(0, 2, 1).contiguous() + dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, + 0).permute(0, 2, 1).contiguous() return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output, topk, renormalize) @@ -62,17 +63,17 @@ def fused_moe_awq( x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) - gate_up = ops.awq_group_gemm(x, w1, w1_scale, w1_qzero, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, False, pack_factor) + gate_up = ops.awq_fused_moe(x, w1, w1_scales, w1_qzeros, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, False, pack_factor) out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), dtype=hidden_states.dtype, device=hidden_states.device) ops.silu_and_mul(out, gate_up) - out = ops.awq_group_gemm(out, w2, w2_scale, w2_qzero, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, True, pack_factor) + out = ops.awq_fused_moe(out, w2, w2_scales, w2_qzeros, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, True, pack_factor) return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a83505cee8e6..76865332a842 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -134,17 +134,15 @@ def __init__( intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, weight_loader=self.weight_loader) - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: int, expert_id: int): - param_data = param.data - + + def _load_fp8_scale(self, param_data: torch.Tensor, + loaded_weight: torch.Tensor, + weight_name: str, expert_id: int) -> None: # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. # Follow up PR to enable fp8 for other MoE models. if "input_scale" in weight_name or "w2.weight_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: + loaded_weight).abs() > 1e-5: raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " @@ -158,35 +156,60 @@ def weight_loader(self, param: torch.nn.Parameter, assert "w1" in weight_name or "w3" in weight_name shard_id = 0 if "w1" in weight_name else 1 param_data[expert_id][shard_id] = loaded_weight + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: int, expert_id: int) -> None: + param_data = param.data + + # Special case for fp8 scales. + if getattr(param, "is_fp8_scale", False): + self._load_fp8_scale(param_data, loaded_weight, weight_name, expert_id) + # Otherwise, load with usual logic. else: tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition - # If packed parameter (e.g. AWQ) and packing is on the - # same dimension as TP sharding, adjust indexing by - # pack factor (8 if int4 packed into int32 parameter). - sharded_dim = 0 if (shard_id == 0 or shard_id == 2) else 1 + # If packed parameter (and packing is on the same dim as + # TP sharding, adjust indexing by pack factor. 
packed_dim = getattr(param, "packed_dim", None) - if packed_dim is not None and packed_dim == sharded_dim: - assert hasattr(param, "pack_factor") + if packed_dim and shard_id != 1: shard_size = shard_size // param.pack_factor shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - # w1, gate_proj case: Load into first shard of w13. - if shard_id == 0: - param_data[expert_id, - 0:shard_size, :] = loaded_weight[shard, :] - # w3, up_proj case: Load into second shard of w13. - elif shard_id == 2: - param_data[expert_id, shard_size:2 * - shard_size, :] = loaded_weight[shard, :] - # w2, down_proj case: Load into only shard of w2. - elif shard_id == 1: - param_data[expert_id, :, :] = loaded_weight[:, shard] + # Usually, weight is saved in format [output_dim, input_dim] + # If transposed, weight is saved in format [input_dim, output_dim] + is_transposed = getattr(param, "is_transposed", False) + if is_transposed: + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, :, 0:shard_size] = loaded_weight[:, shard] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, :, shard_size:2*shard_size] = loaded_weight[:, shard] + # w2, down_proj case: Load into only shard of w2. + elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[shard, :] + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") else: - raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, + 0:shard_size, :] = loaded_weight[shard, :] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size:2 * + shard_size, :] = loaded_weight[shard, :] + # w2, down_proj case: Load into only shard of w2. 
+ elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 80b3098d1b5a..2cf97ac7dbfe 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,7 +4,8 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + fused_moe_awq) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -68,6 +69,8 @@ def get_quant_method( self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -187,33 +190,33 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): # WEIGHTS - w13_weight = Parameter(torch.empty(num_experts, + w13_qweight = Parameter(torch.empty(num_experts, hidden_size, 2 * intermediate_size // self.quant_config.pack_factor, dtype=torch.int32), requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) + layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs( - w13_weight, { + w13_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs - }) + "is_transposed": True, + **extra_weight_attrs}) - w2_weight = Parameter(torch.empty(num_experts, + w2_qweight = Parameter(torch.empty(num_experts, intermediate_size, hidden_size // self.quant_config.pack_factor, dtype=torch.int32), requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) + layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs( - w2_weight, { + w2_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs - }) + "is_transposed": True, + **extra_weight_attrs}) # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. @@ -224,7 +227,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) + set_weight_attrs(w13_scales, { + "is_transposed": True, + **extra_weight_attrs}) w2_scales = Parameter(torch.empty(num_experts, intermediate_size // @@ -232,8 +237,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size, dtype=params_dtype), requires_grad=False) - layer.register_parameter("w2_scale", w2_scales) - set_weight_attrs(w2_scales, **extra_weight_attrs) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, { + "is_transposed": True, + **extra_weight_attrs}) # WEIGHT_ZERO_POINT # Allocate 2 zero points for w1 and w3 respectively. 
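The expert tensors above are registered transposed, i.e. laid out as [input_dim, output_dim], with the output dimension packed by pack_factor and one scale row per quantization group. A standalone sketch of the resulting shapes; the sizes, the float16 stand-in for params_dtype, and the variable names are illustrative only, not a real model configuration:

import torch

# Illustrative sizes, not a real model config.
num_experts, hidden_size, intermediate_size = 4, 256, 512
pack_factor, group_size = 8, 32  # int4 packed into int32; AWQ group size

# Merged gate/up projections, packed along the output dimension.
w13_qweight = torch.zeros(num_experts, hidden_size,
                          2 * intermediate_size // pack_factor,
                          dtype=torch.int32)
# One scale per (group, output channel), kept unpacked.
w13_scales = torch.zeros(num_experts, hidden_size // group_size,
                         2 * intermediate_size, dtype=torch.float16)
# Zero points share the packed layout of the weights.
w13_qzeros = torch.zeros(num_experts, hidden_size // group_size,
                         2 * intermediate_size // pack_factor,
                         dtype=torch.int32)

print(w13_qweight.shape)  # torch.Size([4, 256, 128])
print(w13_scales.shape)   # torch.Size([4, 8, 1024])
print(w13_qzeros.shape)   # torch.Size([4, 8, 128])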
@@ -248,8 +255,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w13_qzeros, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs - }) + "is_transposed": True, + **extra_weight_attrs}) w2_qzeros = Parameter(torch.empty( num_experts, @@ -262,5 +269,24 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w2_qzeros, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **extra_weight_attrs - }) + "is_transposed": True, + **extra_weight_attrs}) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True) -> torch.Tensor: + + return fused_moe_awq(x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + top_k, + renormalize=renormalize, + pack_factor=self.quant_config.pack_factor, + w1_scales=layer.w13_scales, + w2_scales=layer.w2_scales, + w1_qzeros=layer.w13_qzeros, + w2_qzeros=layer.w2_qzeros,) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0c2d2bd3fabe..23ca924d37aa 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -260,23 +260,25 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) + layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, dtype=torch.float32), requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) + layer.register_parameter("w13_weight_scale", w13_weight_scale) # If loading fp8 checkpoint, pass the weight loaders. 
# If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, { + "is_fp8_scale": True, **extra_weight_attrs}) + set_weight_attrs(w13_weight_scale, { + "is_fp8_scale": True, **extra_weight_attrs}) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -285,20 +287,22 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, "Found static activation scheme for checkpoint that " "was not serialized fp8.") - a13_scale = torch.nn.Parameter(torch.ones(num_experts, + w13_input_scale = torch.nn.Parameter(torch.ones(num_experts, dtype=torch.float32), requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, { + "is_fp8_scale": True, **extra_weight_attrs}) - a2_scale = torch.nn.Parameter(torch.ones(num_experts, + w2_input_scale = torch.nn.Parameter(torch.ones(num_experts, dtype=torch.float32), requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, { + "is_fp8_scale": True, **extra_weight_attrs}) else: - layer.a13_scale = None - layer.a2_scale = None + layer.w13_input_scale = None + layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: @@ -311,16 +315,16 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones(layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ + w13_weight[expert, :, :], layer.w13_weight_scale[ expert] = ops.scaled_fp8_quant( layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ + w2_weight[expert, :, :], layer.w2_weight_scale[ expert] = ops.scaled_fp8_quant( layer.w2_weight.data[expert, :, :]) layer.w13_weight = torch.nn.Parameter(w13_weight, @@ -336,7 +340,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") @@ -346,30 +350,30 @@ def process_weights_after_loading(self, layer: Module) -> None: "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " "for each layer. 
") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + layer.w13_input_scale = torch.nn.Parameter(layer.w13_input_scale.max(), requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + layer.w2_input_scale = torch.nn.Parameter(layer.w2_input_scale.max(), requires_grad=False) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. - assert layer.w13_scale is not None + assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values + max_w13_scales = layer.w13_weight_scale.max(dim=1).values for expert_id in range(layer.num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_scale[expert_id][shard_id]) + layer.w13_weight_scale[expert_id][shard_id]) layer.w13_weight[expert_id][ start:start + shard_size, :], _ = ops.scaled_fp8_quant( dq_weight, max_w13_scales[expert_id]) start += shard_size - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) return def apply(self, @@ -387,10 +391,10 @@ def apply(self, renormalize=renormalize, inplace=True, use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale) + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_sclae) class Fp8KVCacheMethod(QuantizeMethodBase): diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f7e0f56c1a46..e57bec65fca4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -25,6 +25,7 @@ def get_model_architecture( # FIXME(woosuk): This is a temporary hack. 
if (model_config.quantization is not None and model_config.quantization != "fp8" + and model_config.quantization != "awq" and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e5bd58a9e97b..fc7d275e2482 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -375,25 +375,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in ["w1", "w3"] else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(self.config.num_local_experts) - for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("experts.w13_weight" - if weight_name in ["w1", "w3"] else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_local_experts) - for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("experts.a13_scale" - if weight_name in ["w1", "w3"] else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, + ("experts.w13_" + if weight_name in ["w1", "w3"] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) for expert_id in range(self.config.num_local_experts) for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) ] From febb027c6e0bdc375909516920c3c958ba93a0b3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 20:33:55 +0000 Subject: [PATCH 07/29] loaded but not running at fp16 --- vllm/model_executor/layers/fused_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 76865332a842..78f94e277d18 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -169,7 +169,8 @@ def weight_loader(self, param: torch.nn.Parameter, else: tp_rank = get_tensor_model_parallel_rank() - shard_size = self.intermediate_size_per_partition + is_gate_up = (shard_id == 0 or shard_id == 2) + shard_size = param_data.shape[2] // 2 if is_gate_up else param_data.shape[1] # If packed parameter (and packing is on the same dim as # TP sharding, adjust indexing by pack factor. From 8bca0091e86ddf8015a9929caadf07554f965ef9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 21:36:40 +0000 Subject: [PATCH 08/29] correctness end-to-end! 
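Because the AWQ expert weights are transposed, the tensor-parallel shard of gate_proj/up_proj is taken along the packed output dimension, so the shard width must be divided by the pack factor before slicing. A standalone sketch of that indexing for a single expert's gate_proj; the sizes and names are illustrative and this is not the loader itself:

import torch

# Illustrative sizes; pack_factor is 8 for int4 packed into int32.
tp_size, tp_rank, pack_factor = 2, 0, 8
hidden_size, intermediate_size = 256, 512
intermediate_per_partition = intermediate_size // tp_size

# Unsharded checkpoint tensor for one expert's gate_proj (transposed, packed).
loaded_w1 = torch.zeros(hidden_size, intermediate_size // pack_factor,
                        dtype=torch.int32)

# Destination: this rank's slice of the merged w13 tensor for that expert.
w13_expert = torch.zeros(hidden_size,
                         2 * intermediate_per_partition // pack_factor,
                         dtype=torch.int32)

shard_size = intermediate_per_partition // pack_factor  # packed output dim
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# gate_proj (w1) fills the first half of w13; up_proj (w3) would fill the rest.
w13_expert[:, 0:shard_size] = loaded_w1[:, shard]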
--- vllm/model_executor/layers/fused_moe/layer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 78f94e277d18..5edb4f371e58 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -170,13 +170,14 @@ def weight_loader(self, param: torch.nn.Parameter, tp_rank = get_tensor_model_parallel_rank() is_gate_up = (shard_id == 0 or shard_id == 2) - shard_size = param_data.shape[2] // 2 if is_gate_up else param_data.shape[1] - - # If packed parameter (and packing is on the same dim as - # TP sharding, adjust indexing by pack factor. - packed_dim = getattr(param, "packed_dim", None) - if packed_dim and shard_id != 1: - shard_size = shard_size // param.pack_factor + if is_gate_up: + shard_size = self.intermediate_size_per_partition + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == 1: + shard_size = shard_size // param.pack_factor + else: + shard_size = param_data.shape[1] + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) @@ -189,7 +190,7 @@ def weight_loader(self, param: torch.nn.Parameter, param_data[expert_id, :, 0:shard_size] = loaded_weight[:, shard] # w3, up_proj case: Load into second shard of w13. elif shard_id == 2: - param_data[expert_id, :, shard_size:2*shard_size] = loaded_weight[:, shard] + param_data[expert_id, :, shard_size:] = loaded_weight[:, shard] # w2, down_proj case: Load into only shard of w2. elif shard_id == 1: param_data[expert_id, :, :] = loaded_weight[shard, :] From 8527d6eaf95785b2305daf67ac61a09800bf8cc6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 21:39:41 +0000 Subject: [PATCH 09/29] formatted --- csrc/ops.h | 3 +- csrc/quantization/awq/gemm_kernels.cu | 4 +- vllm/_custom_ops.py | 22 +++--- .../layers/fused_moe/fused_moe_awq.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 22 +++--- .../model_executor/layers/quantization/awq.py | 70 +++++++++++-------- .../model_executor/layers/quantization/fp8.py | 59 +++++++++------- vllm/model_executor/models/mixtral.py | 7 +- 8 files changed, 101 insertions(+), 88 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index ca52b7aa7d5f..387d36eb4222 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -69,8 +69,7 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, - torch::Tensor _topk_weights, + torch::Tensor _zeros, torch::Tensor _topk_weights, torch::Tensor _sorted_token_ids_ptr, torch::Tensor _expert_ids_ptr, torch::Tensor _num_tokens_post_padded, diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5619266e88c9..64266c5e9b5e 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -762,8 +762,8 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, - int64_t thy) { + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); int num_experts = _kernel.dim() == 2 ? 
1 : _kernel.size(0); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ee47d8f577a1..198cc1e6cfaf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -177,21 +177,17 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) - -def awq_fused_moe(input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - topk_weights: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: int, - mul_weights: bool, - pack_factor: int) -> torch.Tensor: - return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros, topk_weights, - sorted_token_ids, expert_ids, num_tokens_post_padded, +def awq_fused_moe(input: torch.Tensor, qweight: torch.Tensor, + scales: torch.Tensor, qzeros: torch.Tensor, + topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, num_tokens_post_padded: int, + mul_weights: bool, pack_factor: int) -> torch.Tensor: + return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros, + topk_weights, sorted_token_ids, + expert_ids, num_tokens_post_padded, mul_weights, pack_factor) + # gptq def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index 649b67a6b788..6df580612c49 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -47,7 +47,7 @@ def fused_moe_awq( # If large seq_len prefill, dequantize and use the fp16 MoE kernel. do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 if do_naive_dequant: - # TODO: why is this not contiguous alreayd? + # TODO: why is this not contiguous already? dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, 0).permute(0, 2, 1).contiguous() dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5edb4f371e58..b9f063182c5e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -134,15 +134,15 @@ def __init__( intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, weight_loader=self.weight_loader) - + def _load_fp8_scale(self, param_data: torch.Tensor, - loaded_weight: torch.Tensor, - weight_name: str, expert_id: int) -> None: + loaded_weight: torch.Tensor, weight_name: str, + expert_id: int) -> None: # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. # Follow up PR to enable fp8 for other MoE models. if "input_scale" in weight_name or "w2.weight_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: + loaded_weight).abs() > 1e-5: raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " @@ -164,7 +164,8 @@ def weight_loader(self, param: torch.nn.Parameter, # Special case for fp8 scales. if getattr(param, "is_fp8_scale", False): - self._load_fp8_scale(param_data, loaded_weight, weight_name, expert_id) + self._load_fp8_scale(param_data, loaded_weight, weight_name, + expert_id) # Otherwise, load with usual logic. 
else: tp_rank = get_tensor_model_parallel_rank() @@ -177,7 +178,6 @@ def weight_loader(self, param: torch.nn.Parameter, shard_size = shard_size // param.pack_factor else: shard_size = param_data.shape[1] - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) @@ -187,10 +187,12 @@ def weight_loader(self, param: torch.nn.Parameter, if is_transposed: # w1, gate_proj case: Load into first shard of w13. if shard_id == 0: - param_data[expert_id, :, 0:shard_size] = loaded_weight[:, shard] + param_data[expert_id, :, + 0:shard_size] = loaded_weight[:, shard] # w3, up_proj case: Load into second shard of w13. elif shard_id == 2: - param_data[expert_id, :, shard_size:] = loaded_weight[:, shard] + param_data[expert_id, :, + shard_size:] = loaded_weight[:, shard] # w2, down_proj case: Load into only shard of w2. elif shard_id == 1: param_data[expert_id, :, :] = loaded_weight[shard, :] @@ -201,11 +203,11 @@ def weight_loader(self, param: torch.nn.Parameter, # w1, gate_proj case: Load into first shard of w13. if shard_id == 0: param_data[expert_id, - 0:shard_size, :] = loaded_weight[shard, :] + 0:shard_size, :] = loaded_weight[shard, :] # w3, up_proj case: Load into second shard of w13. elif shard_id == 2: param_data[expert_id, shard_size:2 * - shard_size, :] = loaded_weight[shard, :] + shard_size, :] = loaded_weight[shard, :] # w2, down_proj case: Load into only shard of w2. elif shard_id == 1: param_data[expert_id, :, :] = loaded_weight[:, shard] diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2cf97ac7dbfe..318c2a6bf087 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -8,7 +8,7 @@ fused_moe_awq) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -66,7 +66,7 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -191,32 +191,34 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # WEIGHTS w13_qweight = Parameter(torch.empty(num_experts, - hidden_size, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs( w13_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, "is_transposed": True, - **extra_weight_attrs}) + **extra_weight_attrs + }) w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs( w2_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, "is_transposed": True, - **extra_weight_attrs}) + **extra_weight_attrs + }) # WEIGHT_SCALES # Allocate 2 scales for w1 and 
w3 respectively. @@ -229,7 +231,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, { "is_transposed": True, - **extra_weight_attrs}) + **extra_weight_attrs + }) w2_scales = Parameter(torch.empty(num_experts, intermediate_size // @@ -240,7 +243,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, { "is_transposed": True, - **extra_weight_attrs}) + **extra_weight_attrs + }) # WEIGHT_ZERO_POINT # Allocate 2 zero points for w1 and w3 respectively. @@ -256,7 +260,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, "is_transposed": True, - **extra_weight_attrs}) + **extra_weight_attrs + }) w2_qzeros = Parameter(torch.empty( num_experts, @@ -270,23 +275,26 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, "is_transposed": True, - **extra_weight_attrs}) - + **extra_weight_attrs + }) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool = True) -> torch.Tensor: - - return fused_moe_awq(x, - layer.w13_qweight, - layer.w2_qweight, - router_logits, - top_k, - renormalize=renormalize, - pack_factor=self.quant_config.pack_factor, - w1_scales=layer.w13_scales, - w2_scales=layer.w2_scales, - w1_qzeros=layer.w13_qzeros, - w2_qzeros=layer.w2_qzeros,) + + return fused_moe_awq( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + top_k, + renormalize=renormalize, + pack_factor=self.quant_config.pack_factor, + w1_scales=layer.w13_scales, + w2_scales=layer.w2_scales, + w1_qzeros=layer.w13_qzeros, + w2_qzeros=layer.w2_qzeros, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 23ca924d37aa..a446f99ffc41 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -261,14 +261,14 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) + 2, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) # If loading fp8 checkpoint, pass the weight loaders. 
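For fp8 checkpoints each expert keeps two weight scales for the merged w13 tensor, slot 0 for w1/gate_proj and slot 1 for w3/up_proj, until they are merged after loading. A small sketch of how per-projection checkpoint scales land in the [num_experts, 2] buffer; the checkpoint names and values below are made up for illustration and do not come from a real model:

import torch

num_experts = 4
w13_weight_scale = torch.ones(num_experts, 2, dtype=torch.float32)

# Hypothetical checkpoint entries: one scalar scale per expert projection.
checkpoint_scales = {
    "experts.0.w1.weight_scale": torch.tensor(0.02),
    "experts.0.w3.weight_scale": torch.tensor(0.05),
}

for name, scale in checkpoint_scales.items():
    expert_id = int(name.split(".")[1])
    slot = 0 if ".w1." in name else 1  # w1 -> slot 0, w3 -> slot 1
    w13_weight_scale[expert_id, slot] = scale

print(w13_weight_scale[0])  # tensor([0.0200, 0.0500])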
@@ -276,9 +276,13 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: set_weight_attrs(w13_weight_scale, { - "is_fp8_scale": True, **extra_weight_attrs}) + "is_fp8_scale": True, + **extra_weight_attrs + }) set_weight_attrs(w13_weight_scale, { - "is_fp8_scale": True, **extra_weight_attrs}) + "is_fp8_scale": True, + **extra_weight_attrs + }) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -287,19 +291,23 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, "Found static activation scheme for checkpoint that " "was not serialized fp8.") - w13_input_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, { - "is_fp8_scale": True, **extra_weight_attrs}) + "is_fp8_scale": True, + **extra_weight_attrs + }) - w2_input_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, { - "is_fp8_scale": True, **extra_weight_attrs}) + "is_fp8_scale": True, + **extra_weight_attrs + }) else: layer.w13_input_scale = None layer.w2_input_scale = None @@ -315,11 +323,11 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter( - torch.ones(layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) for expert in range(layer.num_experts): w13_weight[expert, :, :], layer.w13_weight_scale[ expert] = ops.scaled_fp8_quant( @@ -340,7 +348,8 @@ def process_weights_after_loading(self, layer: Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": - if layer.w13_input_scale is None or layer.w2_input_scale is None: + if (layer.w13_input_scale is None + or layer.w2_input_scale is None): raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") @@ -350,10 +359,10 @@ def process_weights_after_loading(self, layer: Module) -> None: "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " "for each layer. ") - layer.w13_input_scale = torch.nn.Parameter(layer.w13_input_scale.max(), - requires_grad=False) - layer.w2_input_scale = torch.nn.Parameter(layer.w2_input_scale.max(), - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
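The dequant/requant step described in the comments above can be pictured without the custom fp8 ops: each half of w13 is dequantized with its own scale and re-expressed against the shared maximum. A plain-PyTorch sketch of that idea, with ordinary float tensors standing in for the fp8 storage (the real path goes back to fp8 via ops.scaled_fp8_quant):

import torch

def merge_w13_scales(w13_q: torch.Tensor, w13_scales: torch.Tensor):
    # w13_q: [2 * shard_size, hidden_size], the two shards quantized
    # separately; w13_scales: [2], one scale per shard.
    shard_size = w13_q.shape[0] // 2
    max_scale = w13_scales.max()
    requantized = torch.empty_like(w13_q)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        dequant = w13_q[rows].float() * w13_scales[shard_id]  # dequantize
        requantized[rows] = dequant / max_scale                # requantize
    return requantized, max_scale

w13_q = torch.randn(8, 16)          # stand-in for the fp8 weight storage
w13_scales = torch.tensor([0.02, 0.05])
merged, scale = merge_w13_scales(w13_q, w13_scales)
print(scale)                        # tensor(0.0500)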
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fc7d275e2482..62223559d4fb 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -375,10 +375,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" - if weight_name in ["w1", "w3"] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, - shard_id) for expert_id in range(self.config.num_local_experts) + ("experts.w13_" if weight_name in ["w1", "w3"] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(self.config.num_local_experts) for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) ] From 36d1d82aef39708b6db6d822d2894f547fa91844 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 22:26:18 +0000 Subject: [PATCH 10/29] updared the weight loading logic --- vllm/model_executor/layers/fused_moe/layer.py | 96 +++++++++---------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b9f063182c5e..bc863fcacc5c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -135,9 +135,11 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) - def _load_fp8_scale(self, param_data: torch.Tensor, + def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int) -> None: + param_data = param.data + # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. # Follow up PR to enable fp8 for other MoE models. if "input_scale" in weight_name or "w2.weight_scale" in weight_name: @@ -160,60 +162,52 @@ def _load_fp8_scale(self, param_data: torch.Tensor, def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: int, expert_id: int) -> None: - param_data = param.data + if shard_id not in [0,1,2]: + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") # Special case for fp8 scales. if getattr(param, "is_fp8_scale", False): - self._load_fp8_scale(param_data, loaded_weight, weight_name, + self._load_fp8_scale(param.data, loaded_weight, weight_name, expert_id) - # Otherwise, load with usual logic. - else: - tp_rank = get_tensor_model_parallel_rank() - - is_gate_up = (shard_id == 0 or shard_id == 2) - if is_gate_up: - shard_size = self.intermediate_size_per_partition - packed_dim = getattr(param, "packed_dim", None) - if packed_dim == 1: - shard_size = shard_size // param.pack_factor - else: - shard_size = param_data.shape[1] - - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - - # Usually, weight is saved in format [output_dim, input_dim] - # If transposed, weight is saved in format [input_dim, output_dim] - is_transposed = getattr(param, "is_transposed", False) - if is_transposed: - # w1, gate_proj case: Load into first shard of w13. - if shard_id == 0: - param_data[expert_id, :, - 0:shard_size] = loaded_weight[:, shard] - # w3, up_proj case: Load into second shard of w13. - elif shard_id == 2: - param_data[expert_id, :, - shard_size:] = loaded_weight[:, shard] - # w2, down_proj case: Load into only shard of w2. 
- elif shard_id == 1: - param_data[expert_id, :, :] = loaded_weight[shard, :] - else: - raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") - else: - # w1, gate_proj case: Load into first shard of w13. - if shard_id == 0: - param_data[expert_id, - 0:shard_size, :] = loaded_weight[shard, :] - # w3, up_proj case: Load into second shard of w13. - elif shard_id == 2: - param_data[expert_id, shard_size:2 * - shard_size, :] = loaded_weight[shard, :] - # w2, down_proj case: Load into only shard of w2. - elif shard_id == 1: - param_data[expert_id, :, :] = loaded_weight[:, shard] - else: - raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + return + + expert = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + is_gate_proj = (shard_id == 0) + is_down_proj = (shard_id == 1) + is_up_proj = (shard_id == 2) + + # If transposed, weight is saved as [input_dim, output_dim] + # Otherwise, weight is saved as [output_dim, input_dim] + is_transposed = getattr(param, "is_transposed", False) + input_dim = 0 if is_transposed else 1 + output_dim = 1 if is_transposed else 0 + + # Index the loaded weight for tp sharding. + # * down_proj: "RowParallel" so tp sharding on input_dim + if (is_down_proj): + shard_dim = input_dim + shard_size = expert.shape[shard_dim] + # * gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + elif (is_gate_proj or is_up_proj): + shard_dim = output_dim + shard_size = expert.shape[output_dim] // 2 + offset = shard_size * tp_rank + loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + + # Narrow parameter and load. + # w1, gate_proj case: Load into first shard of w13. + if is_gate_proj: + expert = expert.narrow(shard_dim, 0, shard_size) + expert.copy_(loaded_weight) + # w3, up_proj case: Load into second shard of w13. + elif is_up_proj: + expert = expert.narrow(shard_dim, shard_size, shard_size) + expert.copy_(loaded_weight) + # w2, down_proj case: Load into only shard of w2. + elif is_down_proj: + expert.copy_(loaded_weight) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): From 6943e80ba49fbb7ef1daeebe86050a4bec8de7b7 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 23:07:04 +0000 Subject: [PATCH 11/29] stash --- vllm/model_executor/layers/fused_moe/layer.py | 26 ++++++++++--------- .../model_executor/layers/quantization/fp8.py | 16 ++++++------ vllm/model_executor/models/mixtral.py | 2 +- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bc863fcacc5c..22eb447a9cf2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -142,7 +142,7 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. # Follow up PR to enable fp8 for other MoE models. 
- if "input_scale" in weight_name or "w2.weight_scale" in weight_name: + if "input_scale" in weight_name or "w2_weight_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( @@ -171,7 +171,7 @@ def weight_loader(self, param: torch.nn.Parameter, expert_id) return - expert = param.data[expert_id] + expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() is_gate_proj = (shard_id == 0) is_down_proj = (shard_id == 1) @@ -187,26 +187,28 @@ def weight_loader(self, param: torch.nn.Parameter, # * down_proj: "RowParallel" so tp sharding on input_dim if (is_down_proj): shard_dim = input_dim - shard_size = expert.shape[shard_dim] + shard_size = expert_data.shape[shard_dim] # * gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim elif (is_gate_proj or is_up_proj): shard_dim = output_dim - shard_size = expert.shape[output_dim] // 2 + shard_size = expert_data.shape[output_dim] // 2 offset = shard_size * tp_rank loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) # Narrow parameter and load. - # w1, gate_proj case: Load into first shard of w13. + # w1, gate_proj: Load into first shard of w13. if is_gate_proj: - expert = expert.narrow(shard_dim, 0, shard_size) - expert.copy_(loaded_weight) - # w3, up_proj case: Load into second shard of w13. + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + expert_data.copy_(loaded_weight) + # w3, up_proj: Load into second shard of w13. elif is_up_proj: - expert = expert.narrow(shard_dim, shard_size, shard_size) - expert.copy_(loaded_weight) - # w2, down_proj case: Load into only shard of w2. + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + # w2, down_proj: Load into only shard of w2. elif is_down_proj: - expert.copy_(loaded_weight) + expert_data.copy_(loaded_weight) + else: + raise ValueError def forward(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a446f99ffc41..a28603dca04d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -266,10 +266,10 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) # If loading fp8 checkpoint, pass the weight loaders. 
# If loading an fp16 checkpoint, do not (we will quantize in @@ -279,7 +279,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, "is_fp8_scale": True, **extra_weight_attrs }) - set_weight_attrs(w13_weight_scale, { + set_weight_attrs(w2_weight_scale, { "is_fp8_scale": True, **extra_weight_attrs }) @@ -353,8 +353,8 @@ def process_weights_after_loading(self, layer: Module) -> None: raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " @@ -403,7 +403,7 @@ def apply(self, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_sclae) + a2_scale=layer.w2_input_scale) class Fp8KVCacheMethod(QuantizeMethodBase): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 62223559d4fb..76cdddbdf242 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -407,7 +407,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break From 71e5129904db48d0700b1e4f03a7daf939107b4d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 13 Jul 2024 23:13:51 +0000 Subject: [PATCH 12/29] fixed fp8 --- vllm/model_executor/layers/fused_moe/layer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 22eb447a9cf2..a1f52e4ff67f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -137,7 +137,7 @@ def __init__( def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - expert_id: int) -> None: + shard_id: int, expert_id: int) -> None: param_data = param.data # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. @@ -155,9 +155,9 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, elif "weight_scale" in weight_name: # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. - assert "w1" in weight_name or "w3" in weight_name - shard_id = 0 if "w1" in weight_name else 1 - param_data[expert_id][shard_id] = loaded_weight + assert shard_id == 0 or shard_id == 2 + shard_idx = 0 if shard_id == 0 else 1 + param_data[expert_id][shard_idx] = loaded_weight def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, @@ -168,7 +168,7 @@ def weight_loader(self, param: torch.nn.Parameter, # Special case for fp8 scales. 
if getattr(param, "is_fp8_scale", False): self._load_fp8_scale(param.data, loaded_weight, weight_name, - expert_id) + shard_id, expert_id) return expert_data = param.data[expert_id] From 5b73064b602b5608522b673a1eac8faaca2db1c6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 11:30:28 +0000 Subject: [PATCH 13/29] merged --- vllm/model_executor/layers/fused_moe/layer.py | 35 +++++-------------- vllm/model_executor/models/mixtral.py | 1 - 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 30895a08c6f0..4a50bf4dfcd5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -212,33 +212,14 @@ def make_expert_params_mapping( ckpt_up_proj_name: str, num_experts: int) -> List[Tuple[str, str, int, int]]: - gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] - gate_down_up = [ - ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name - ] - return [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in gate_up else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) - ] + [ - # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" - if weight_name in gate_up else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) + ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) - ] + [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" - if weight_name in gate_up else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) - ] + for shard_id, weight_name in enumerate([ + ckpt_gate_proj_name, + ckpt_down_proj_name, + ckpt_up_proj_name, + ]) + ] \ No newline at end of file diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ae81e0bc3035..35e0dca5c1e7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -372,7 +372,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="w1", From 2ef2c92f3e62480ce5bc370ac3593dae5db144a8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 12:05:21 +0000 Subject: [PATCH 14/29] formatting --- .../layers/fused_moe/__init__.py | 5 +- .../layers/fused_moe/fused_moe_awq.py | 32 ++-- vllm/model_executor/layers/fused_moe/layer.py | 168 +++++++++++------- .../model_executor/layers/quantization/awq.py | 31 ++-- .../model_executor/layers/quantization/fp8.py | 42 ++--- 5 files changed, 150 insertions(+), 128 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 
90208f4f6736..dd9c0a71513e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,14 +1,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) -from vllm.model_executor.layers.fused_moe.fused_moe_awq import fused_moe_awq +from vllm.model_executor.layers.fused_moe.fused_moe_awq import ( + fused_experts_awq) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) __all__ = [ "fused_moe", - "fused_moe_awq", "fused_topk", "fused_experts", + "fused_experts_awq", "get_config_file_name", "grouped_topk", "FusedMoE", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index 6df580612c49..a0e3ce33a248 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -4,48 +4,44 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from .fused_moe import fused_moe, fused_topk, moe_align_block_size +from .fused_moe import fused_experts, moe_align_block_size logger = init_logger(__name__) +NAIVE_THRESHOLD = 1024 -def fused_moe_awq( + +def fused_experts_awq( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - pack_factor: int, w1_scales: torch.Tensor, w2_scales: torch.Tensor, w1_qzeros: torch.Tensor, w2_qzeros: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + pack_factor: int, ) -> torch.Tensor: """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. + This function computes an AWQ fused_expert. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - pack_factor (int): Weight packing factor (int4 in int32 == 8) + - w2 (torch.Tensor): The second set of expert weights. - w1_scales (torch.Tensor): scale to be used for w1. - w2_scales (torch.Tensor): scale to be used for w2. - w1_qzeros (torch.Tensor): zero point to be used for w1. - w2_qzeros (torch.Tensor): zero point to be used for w2. + - pack_factor (int): Weight packing factor (int4 in int32 == 8) Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # If large seq_len prefill, dequantize and use the fp16 MoE kernel. - do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 + do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD if do_naive_dequant: # TODO: why is this not contiguous already? 
dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, @@ -53,11 +49,9 @@ def fused_moe_awq( dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, 0).permute(0, 2, 1).contiguous() - return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output, - topk, renormalize) + return fused_experts(hidden_states, dequant_w1, dequant_w2, + topk_weights, topk_ids) - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0]) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4a50bf4dfcd5..c8f30fcc8123 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -7,7 +7,9 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk, + grouped_topk) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -24,15 +26,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise NotImplementedError @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -61,26 +57,16 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=True) class FusedMoE(torch.nn.Module): @@ -152,6 +138,25 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, int]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate([ + ckpt_gate_proj_name, + ckpt_down_proj_name, + ckpt_up_proj_name, + ]) + ] + def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, 
weight_name: str, shard_id: int, expert_id: int) -> None: @@ -180,46 +185,89 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # shard_id 1 == down_proj / w2 else: param_data[expert_id] = loaded_weight - # Weights + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: int, expert_id: int) -> None: + if shard_id not in [0, 1, 2]: + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + + # Special case for fp8 scales. + if getattr(param, "is_fp8_scale", False): + self._load_fp8_scale(param.data, loaded_weight, weight_name, + shard_id, expert_id) + return + + expert_data = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + is_gate_proj = (shard_id == 0) + is_down_proj = (shard_id == 1) + is_up_proj = (shard_id == 2) + + # If transposed, weight is saved as [input_dim, output_dim] + # Otherwise, weight is saved as [output_dim, input_dim] + is_transposed = getattr(param, "is_transposed", False) + input_dim = 0 if is_transposed else 1 + output_dim = 1 if is_transposed else 0 + + # Index the loaded weight for tp sharding. + # * down_proj: "RowParallel" so tp sharding on input_dim + if (is_down_proj): + shard_dim = input_dim + shard_size = expert_data.shape[shard_dim] + # * gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + elif (is_gate_proj or is_up_proj): + shard_dim = output_dim + shard_size = expert_data.shape[output_dim] // 2 + offset = shard_size * tp_rank + loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + + # Narrow parameter and load. + # w1, gate_proj: Load into first shard of w13. + if is_gate_proj: + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + expert_data.copy_(loaded_weight) + # w3, up_proj: Load into second shard of w13. + elif is_up_proj: + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + # w2, down_proj: Load into only shard of w2. + elif is_down_proj: + expert_data.copy_(loaded_weight) else: raise ValueError - + + def _select_experts(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + # DeekSeekv2 uses grouped_top_k + if self.use_grouped_topk: + assert (self.num_expert_group is not None + and self.topk_group is not None) + topk_weights, topk_ids = grouped_topk(hidden_states, router_logits, + self.top_k, self.renormalize, + self.num_expert_group, + self.topk_group) + else: + topk_weights, topk_ids = fused_topk(hidden_states, router_logits, + self.top_k, self.renormalize) + + return topk_weights, topk_ids def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - # Matrix multiply. - final_hidden_states = self.quant_method.apply( - self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group) + # Select experts. + topk_weights, topk_ids = self._select_experts(hidden_states, + router_logits) + + # Call fused kernel. + final_hidden_states = self.quant_method.apply(self, hidden_states, + topk_weights, topk_ids) + # Optionally reduce. 
if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) return final_hidden_states - - @classmethod - def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, int]]: - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) - for shard_id, weight_name in enumerate([ - ckpt_gate_proj_name, - ckpt_down_proj_name, - ckpt_up_proj_name, - ]) - ] \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 318c2a6bf087..b1f8d1e300b7 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -5,7 +5,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_moe_awq) + fused_experts_awq) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -278,23 +278,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, **extra_weight_attrs }) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True) -> torch.Tensor: - - return fused_moe_awq( - x, - layer.w13_qweight, - layer.w2_qweight, - router_logits, - top_k, - renormalize=renormalize, - pack_factor=self.quant_config.pack_factor, - w1_scales=layer.w13_scales, - w2_scales=layer.w2_scales, - w1_qzeros=layer.w13_qzeros, - w2_qzeros=layer.w2_qzeros, - ) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + + return fused_experts_awq(x, layer.w13_qweight, layer.w2_qweight, + layer.w13_scales, layer.w2_scales, + layer.w13_qzeros, layer.w2_qzeros, + topk_weights, topk_ids, + self.quant_config.pack_factor) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e1123b6bfcf8..b00a15f4e421 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_moe) + fused_experts) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -385,31 +385,21 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - 
use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) class Fp8KVCacheMethod(QuantizeMethodBase): From db33c3fca1c014a517fa3ceddce60bc18f50e269 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 12:12:23 +0000 Subject: [PATCH 15/29] better comments --- vllm/model_executor/layers/fused_moe/layer.py | 51 +++++++++---------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c8f30fcc8123..ca433fab568d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -142,7 +142,7 @@ def __init__( def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, int]]: + num_experts: int) -> List[Tuple[str, str, int, str]]: return [ # (param_name, weight_name, expert_id, shard_id) @@ -150,16 +150,16 @@ def make_expert_params_mapping( in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate([ - ckpt_gate_proj_name, - ckpt_down_proj_name, - ckpt_up_proj_name, - ]) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] ] def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: int, expert_id: int) -> None: + shard_id: str, expert_id: int) -> None: param_data = param.data # Input scales can be loaded directly and should be equal. @@ -174,23 +174,21 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # Weight scales elif "weight_scale" in weight_name: # If we are in merged column case (gate_up_proj) - # shard_id 0 == gate_proj / w1 - # shard_id 2 == up_proj / w3 - if shard_id == 0 or shard_id == 2: + if shard_id == "w1" or shard_id == "w3": # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == 0 else 1 + idx = 0 if shard_id == "w1" else 1 param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) - # shard_id 1 == down_proj / w2 else: param_data[expert_id] = loaded_weight def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: int, expert_id: int) -> None: - if shard_id not in [0, 1, 2]: - raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + shard_id: str, expert_id: int) -> None: + if shard_id not in ["w1", "w2", "w3"]: + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") # Special case for fp8 scales. 
if getattr(param, "is_fp8_scale", False): @@ -200,9 +198,6 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - is_gate_proj = (shard_id == 0) - is_down_proj = (shard_id == 1) - is_up_proj = (shard_id == 2) # If transposed, weight is saved as [input_dim, output_dim] # Otherwise, weight is saved as [output_dim, input_dim] @@ -211,28 +206,28 @@ def weight_loader(self, param: torch.nn.Parameter, output_dim = 1 if is_transposed else 0 # Index the loaded weight for tp sharding. - # * down_proj: "RowParallel" so tp sharding on input_dim - if (is_down_proj): + # down_proj: "RowParallel" so tp sharding on input_dim + if (shard_id == "w2"): shard_dim = input_dim shard_size = expert_data.shape[shard_dim] - # * gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - elif (is_gate_proj or is_up_proj): + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + elif (shard_id == "w1" or shard_id == "w3"): shard_dim = output_dim shard_size = expert_data.shape[output_dim] // 2 offset = shard_size * tp_rank loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) # Narrow parameter and load. - # w1, gate_proj: Load into first shard of w13. - if is_gate_proj: + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": expert_data = expert_data.narrow(shard_dim, 0, shard_size) expert_data.copy_(loaded_weight) - # w3, up_proj: Load into second shard of w13. - elif is_up_proj: + # w3, up_proj: Load into second logical weight of w13. + elif shard_id == "w3": expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) - # w2, down_proj: Load into only shard of w2. - elif is_down_proj: + # w2, down_proj: Load into only logical weight of w2. 
+ elif shard_id == "w2": expert_data.copy_(loaded_weight) else: raise ValueError From f6f60cd166c61a7aa8683050629955565f7c8ddd Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 12:33:50 +0000 Subject: [PATCH 16/29] added --- .../configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml new file mode 100644 index 000000000000..c36feb74682d --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m casperhansen/mixtral-instruct-awq -b auto -l 1000 -f 5 -t 2 +model_name: "casperhansen/mixtral-instruct-awq" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.601 + - name: "exact_match,flexible-extract" + value: 0.599 +limit: 1000 +num_fewshot: 5 From d9def7e3e7b6e45ee084dd274d50296efad700b3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 12:59:31 +0000 Subject: [PATCH 17/29] formatted --- .../configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml | 11 ----------- vllm/model_executor/layers/fused_moe/layer.py | 3 +-- 2 files changed, 1 insertion(+), 13 deletions(-) delete mode 100644 .buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml deleted file mode 100644 index c36feb74682d..000000000000 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-AWQ.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m casperhansen/mixtral-instruct-awq -b auto -l 1000 -f 5 -t 2 -model_name: "casperhansen/mixtral-instruct-awq" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.601 - - name: "exact_match,flexible-extract" - value: 0.599 -limit: 1000 -num_fewshot: 5 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ca433fab568d..b67f823a5393 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -149,8 +149,7 @@ def make_expert_params_mapping( ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) - for shard_id, weight_name in [ + for expert_id in range(num_experts) for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), ("w3", ckpt_up_proj_name), From 16eacd0d47f1c6b16f6210df38decf42cc81de0f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 14 Jul 2024 13:20:28 +0000 Subject: [PATCH 18/29] stash --- .../run-lm-eval-gsm-vllm-baseline.sh | 2 +- tests/kernels/test_moe.py | 79 +++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index dbb21be4f86e..2f04cc1283df 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while 
getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2f9eee420f27..3deb744a3d5d 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -99,3 +99,82 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + +def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score, + topk): + score = torch.softmax(score.float(), dim=-1) + topk_weight, topk_ids = torch.topk(score, topk) + (B, D) = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[2] * 8, + dtype=a.dtype, + device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0) + dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0) + r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1)) + out[mask] = torch.matmul(r1, dw2).to(out.dtype) + return (out.view(B, -1, w2.shape[2] * 8) * + topk_weight.view(B, -1, 1)).sum(dim=1).half() + + +@pytest.mark.parametrize("m", [1024, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2, 6]) +def test_fused_moe_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + # awq requires minimum capability 75 + if torch.version.hip is not None: + return + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < 75: + return + + RANGE = 1000000000 + groupsize = 128 + a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 + qw1 = torch.randint(-RANGE, + RANGE, (e, k, n * 2 // 8), + dtype=torch.int, + device='cuda') + qw2 = torch.randint(-RANGE, + RANGE, (e, n, k // 8), + dtype=torch.int, + device='cuda') + + scale1 = torch.randn( + (e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn( + (e, n // groupsize, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, + RANGE, (e, k // groupsize, (n * 2 // 32) * 4), + dtype=torch.int32, + device='cuda') + zero2 = torch.randint(-RANGE, + RANGE, (e, n // groupsize, (k // 32) * 4), + dtype=torch.int32, + device='cuda') + w1 = {"qweight": qw1, "scales": scale1, "qzeros": zero1} + w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2} + + score = torch.randn((m, e), device='cuda', dtype=torch.half) + + awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) + torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, + score, topk) + cuda_output = awq_method.apply_moe_weights(w1, w2, a, score, topk, False) + assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0) From d6a032efd8d2d8227c6dbe05a6770efc718db8ce Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Jul 2024 17:06:49 +0000 Subject: [PATCH 19/29] clean-up, fix tests --- tests/kernels/test_moe.py | 6 +++ .../layers/fused_moe/fused_moe_awq.py | 
1 + vllm/model_executor/layers/fused_moe/layer.py | 50 ++++++++++--------- .../model_executor/layers/quantization/awq.py | 25 ++++++---- vllm/model_executor/model_loader/loader.py | 1 + 5 files changed, 50 insertions(+), 33 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 3deb744a3d5d..70a710a93589 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -7,8 +7,11 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from vllm import _custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + AWQLinearMethod) from vllm.model_executor.models.mixtral import MixtralMoE @@ -100,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype): rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score, topk): score = torch.softmax(score.float(), dim=-1) @@ -176,5 +180,7 @@ def test_fused_moe_awq( awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, score, topk) + # TODO @dsikka: what is this supposed to be applying? + # LinearMethod used for a fused test? cuda_output = awq_method.apply_moe_weights(w1, w2, a, score, topk, False) assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index a0e3ce33a248..3d618d9a0c11 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -44,6 +44,7 @@ def fused_experts_awq( do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD if do_naive_dequant: # TODO: why is this not contiguous already? 
+ # from @dsikka: because of the permutation operation dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, 0).permute(0, 2, 1).contiguous() dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index adb00b0c0f5e..6f53888c31a0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -7,10 +7,10 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk, grouped_topk) -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -62,18 +62,19 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor: - return self.forward(x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids) + return self.forward(x=x, + layer=layer, + topk_weights=topk_weights, + topk_ids=topk_ids) def forward_cuda( - self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor: - return fused_experts(x=x, + return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -85,15 +86,17 @@ def forward_cpu(self, *args, **kwargs): "The CPU backend currently does not support MoE.") def forward_tpu( - self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor: - + #assert not use_grouped_topk #assert num_expert_group is None #assert topk_group is None - return fused_experts(x=x, + return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -151,8 +154,6 @@ def __init__( self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group @@ -268,8 +269,8 @@ def _select_experts(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): # DeekSeekv2 uses grouped_top_k if self.use_grouped_topk: - assert (self.num_expert_group is not None - and self.topk_group is not None) + assert self.topk_group is not None + assert self.num_expert_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, router_logits, self.top_k, self.renormalize, self.num_expert_group, @@ -285,12 +286,15 @@ def forward(self, hidden_states: torch.Tensor, assert self.quant_method is not None # Select experts. - topk_weights, topk_ids = self._select_experts(hidden_states, - router_logits) + topk_weights, topk_ids = self._select_experts( + hidden_states=hidden_states, router_logits=router_logits) # Call fused kernel. 
- final_hidden_states = self.quant_method.apply(self, hidden_states, - topk_weights, topk_ids) + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids) # Optionally reduce. if self.reduce_results and self.tp_size > 1: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index f480623c768a..1be9ec695343 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torch.nn.parameter import Parameter @@ -8,7 +8,7 @@ fused_experts_awq) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs @@ -65,9 +65,9 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQLinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module, + prefix: str) -> Optional[Union["AWQMoEMethod", "AWQLinearMethod"]]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -283,8 +283,13 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor: - return fused_experts_awq(x, layer.w13_qweight, layer.w2_qweight, - layer.w13_scales, layer.w2_scales, - layer.w13_qzeros, layer.w2_qzeros, - topk_weights, topk_ids, - self.quant_config.pack_factor) + return fused_experts_awq(hidden_states=x, + w1=layer.w13_qweight, + w2=layer.w2_qweight, + w1_scales=layer.w13_scales, + w2_scales=layer.w2_scales, + w1_qzeros=layer.w13_qzeros, + w2_qzeros=layer.w2_qzeros, + topk_weights=topk_weights, + topk_ids=topk_ids, + pack_factor=self.quant_config.pack_factor) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020d..fecca801a1b6 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -405,6 +405,7 @@ def load_model(self, *, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + self._verify_config(model_config, parallel_config) if parallel_config.tensor_parallel_size > 1: From 8d52ae543d45ad9db55ef3b8b5ce74f532772140 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 1 Aug 2024 15:35:13 +0000 Subject: [PATCH 20/29] normalize weights to prevent illegal memory --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- vllm/model_executor/layers/fused_moe/fused_moe_awq.py | 4 ++-- vllm/model_executor/models/deepseek_v2.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 413c0b6d0924..c9b4dad0a19e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -442,8 +442,8 @@ def fused_experts(hidden_states: torch.Tensor, assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, 
"topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" + #assert w1.is_contiguous(), "Expert weights1 must be contiguous" + #assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index 3d618d9a0c11..4b2935727e1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -46,9 +46,9 @@ def fused_experts_awq( # TODO: why is this not contiguous already? # from @dsikka: because of the permutation operation dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, - 0).permute(0, 2, 1).contiguous() + 0).permute(0, 2, 1) dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, - 0).permute(0, 2, 1).contiguous() + 0).permute(0, 2, 1) return fused_experts(hidden_states, dequant_w1, dequant_w2, topk_weights, topk_ids) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2e3e9b6f2792..cdff4cb123ae 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -115,7 +115,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=config.norm_topk_prob, + renormalize=True, quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, From c08a5da484e4999b3c5f59b1cad1ed38ce9f5fa7 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 1 Aug 2024 21:14:11 +0000 Subject: [PATCH 21/29] all MoE tests working --- tests/kernels/test_moe.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 70a710a93589..eb72445190e8 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -9,9 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.quantization.awq import (AWQConfig, - AWQLinearMethod) +from vllm.model_executor.layers.fused_moe import (fused_experts_awq, fused_moe, + fused_topk) +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.mixtral import MixtralMoE @@ -177,10 +177,19 @@ def test_fused_moe_awq( score = torch.randn((m, e), device='cuda', dtype=torch.half) - awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) + quant_config = AWQConfig(4, groupsize, False) torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, score, topk) - # TODO @dsikka: what is this supposed to be applying? - # LinearMethod used for a fused test? 
- cuda_output = awq_method.apply_moe_weights(w1, w2, a, score, topk, False) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + cuda_output = fused_experts_awq(hidden_states=a, + w1=w1["qweight"], + w2=w2["qweight"], + w1_scales=w1["scales"], + w2_scales=w2["scales"], + w1_qzeros=w1["qzeros"], + w2_qzeros=w2["qzeros"], + topk_weights=topk_weights, + topk_ids=topk_ids, + pack_factor=quant_config.pack_factor) assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0) From 7325e78a9cf24a62b32e8e4dd2b42d1fae626a84 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 2 Aug 2024 02:52:50 +0000 Subject: [PATCH 22/29] revert to reproduce error --- vllm/model_executor/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cdff4cb123ae..551ba595810d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -115,7 +115,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=True, + renormalize=False, quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, From 0538dcc040e39928864db51453d40e9303252602 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 2 Aug 2024 20:45:00 +0000 Subject: [PATCH 23/29] update to comply with main --- vllm/model_executor/layers/fused_moe/layer.py | 40 ++++++++----------- .../model_executor/layers/quantization/awq.py | 4 +- .../model_executor/layers/quantization/fp8.py | 4 +- vllm/model_executor/models/deepseek_v2.py | 2 +- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6f53888c31a0..31191b45382b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -28,8 +28,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, @abstractmethod def apply(self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + **kwargs) -> torch.Tensor: raise NotImplementedError @@ -59,21 +59,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_weight, extra_weight_attrs) def apply(self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + **kwargs) -> torch.Tensor: return self.forward(x=x, layer=layer, topk_weights=topk_weights, - topk_ids=topk_ids) + topk_ids=topk_ids, + **kwargs) - def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> torch.Tensor: + def forward_cuda(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + **kwargs) -> torch.Tensor: return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -85,17 +82,11 @@ def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") - def forward_tpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> torch.Tensor: - - #assert not use_grouped_topk - #assert num_expert_group is None - #assert topk_group is None + 
def forward_tpu(self, layer: torch.nn.Module, x: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + use_grouped_topk: bool) -> torch.Tensor: + + assert not use_grouped_topk return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -294,7 +285,8 @@ def forward(self, hidden_states: torch.Tensor, layer=self, x=hidden_states, topk_weights=topk_weights, - topk_ids=topk_ids) + topk_ids=topk_ids, + use_grouped_topk=self.use_grouped_topk) # Optionally reduce. if self.reduce_results and self.tp_size > 1: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 1be9ec695343..2aff6fea50e7 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -280,8 +280,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, }) def apply(self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + **kwargs) -> torch.Tensor: return fused_experts_awq(hidden_states=x, w1=layer.w13_qweight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 72666690ccc0..939960ba395a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -408,8 +408,8 @@ def process_weights_after_loading(self, layer: Module) -> None: return def apply(self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + **kwargs) -> torch.Tensor: return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 551ba595810d..cdff4cb123ae 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -115,7 +115,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=False, + renormalize=True, quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, From 0ba00abd580b86b1240e3ee87730a08ff2b5ccfe Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Sun, 4 Aug 2024 21:21:05 +0000 Subject: [PATCH 24/29] PR comments --- .../layers/fused_moe/__init__.py | 2 +- .../layers/fused_moe/fused_moe.py | 4 +-- .../layers/fused_moe/fused_moe_awq.py | 11 ++++---- vllm/model_executor/layers/fused_moe/layer.py | 6 ++--- .../model_executor/layers/quantization/awq.py | 27 +++++++++++-------- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9b90483644ba..f474308c432e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -18,8 +18,8 @@ __all__ += [ "fused_moe", - "fused_topk", "fused_experts", + "fused_topk", "get_config_file_name", "grouped_topk", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c9b4dad0a19e..413c0b6d0924 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -442,8 +442,8 @@ def fused_experts(hidden_states: torch.Tensor, assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert 
topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - #assert w1.is_contiguous(), "Expert weights1 must be contiguous" - #assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index 4b2935727e1c..c9ad7922e338 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -3,8 +3,8 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger - -from .fused_moe import fused_experts, moe_align_block_size +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, moe_align_block_size) logger = init_logger(__name__) @@ -43,12 +43,11 @@ def fused_experts_awq( # If large seq_len prefill, dequantize and use the fp16 MoE kernel. do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD if do_naive_dequant: - # TODO: why is this not contiguous already? - # from @dsikka: because of the permutation operation + # NOTE: not contiguous because of the permutation operation dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, - 0).permute(0, 2, 1) + 0).permute(0, 2, 1).contiguous() dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, - 0).permute(0, 2, 1) + 0).permute(0, 2, 1).contiguous() return fused_experts(hidden_states, dequant_w1, dequant_w2, topk_weights, topk_ids) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 31191b45382b..2852cdddfd02 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -225,9 +225,9 @@ def weight_loader(self, param: torch.nn.Parameter, # If transposed, weight is saved as [input_dim, output_dim] # Otherwise, weight is saved as [output_dim, input_dim] - is_transposed = getattr(param, "is_transposed", False) - input_dim = 0 if is_transposed else 1 - output_dim = 1 if is_transposed else 0 + # Default is not transposed/input dim is dim 1 + input_dim = getattr(param, "input_dim", 1) + output_dim = getattr(param, "output_dim", 0) # Index the loaded weight for tp sharding. 
# down_proj: "RowParallel" so tp sharding on input_dim diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2aff6fea50e7..c2c210092417 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -8,7 +8,7 @@ fused_experts_awq) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -65,9 +65,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_quant_method( - self, layer: torch.nn.Module, - prefix: str) -> Optional[Union["AWQMoEMethod", "AWQLinearMethod"]]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -202,7 +201,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w13_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) @@ -217,7 +217,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w2_qweight, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) @@ -231,7 +232,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, { - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) @@ -243,7 +245,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, { - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) @@ -260,7 +263,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w13_qzeros, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) @@ -275,7 +279,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w2_qzeros, { "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "is_transposed": True, + "input_dim": 0, + "output_dim": 1, **extra_weight_attrs }) From 419eb7d378551f31d11b0ee0fa57e685ac56c8e5 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 5 Aug 2024 18:59:44 +0000 Subject: [PATCH 25/29] fix tpu forward pass; use kwargs --- .../layers/fused_moe/fused_moe_awq.py | 26 ++++++++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 24 ++++++++++------- vllm/model_executor/models/deepseek_v2.py | 2 +- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py index c9ad7922e338..f1b06993098a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_awq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_awq.py @@ -57,17 +57,31 @@ 
def fused_experts_awq( x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) - gate_up = ops.awq_fused_moe(x, w1, w1_scales, w1_qzeros, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, False, pack_factor) + gate_up = ops.awq_fused_moe(input=x, + qweight=w1, + scales=w1_scales, + qzeros=w1_qzeros, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_weights=False, + pack_factor=pack_factor) out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), dtype=hidden_states.dtype, device=hidden_states.device) ops.silu_and_mul(out, gate_up) - out = ops.awq_fused_moe(out, w2, w2_scales, w2_qzeros, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, True, pack_factor) + out = ops.awq_fused_moe(input=out, + qweight=w2, + scales=w2_scales, + qzeros=w2_qzeros, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_weights=True, + pack_factor=pack_factor) return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2852cdddfd02..976a8f5e8ad7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -83,16 +83,18 @@ def forward_cpu(self, *args, **kwargs): "The CPU backend currently does not support MoE.") def forward_tpu(self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - use_grouped_topk: bool) -> torch.Tensor: + use_grouped_topk: bool, topk: int, + gating_output: torch.Tensor, renormalize: bool, + **kwargs) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True) + return fused_moe(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=topk, + gating_output=gating_output, + renormalize=renormalize) class FusedMoE(torch.nn.Module): @@ -254,7 +256,8 @@ def weight_loader(self, param: torch.nn.Parameter, elif shard_id == "w2": expert_data.copy_(loaded_weight) else: - raise ValueError + raise ValueError( + f"Expected shard_id w1,w2 or w3 but got {shard_id}") def _select_experts(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -286,6 +289,9 @@ def forward(self, hidden_states: torch.Tensor, x=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, + topk=self.top_k, + gating_output=router_logits, + renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk) # Optionally reduce. 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cdff4cb123ae..2e3e9b6f2792 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -115,7 +115,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=True, + renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, From 5666fcbc6c3e12952ccae4b63dc02a0a63c7a97f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 Aug 2024 20:05:33 +0000 Subject: [PATCH 26/29] fix triton import --- .../layers/fused_moe/__init__.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index f474308c432e..0f8060066e26 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,22 +1,18 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_awq import ( - fused_experts_awq) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) from vllm.triton_utils import HAS_TRITON -__all__ = [ - "fused_experts_awq", - "FusedMoE", - "FusedMoEMethodBase", -] - if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.fused_moe_awq import ( + fused_experts_awq) + from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) - __all__ += [ + __all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "fused_experts_awq", "fused_moe", "fused_experts", "fused_topk", From 8013ad4d57f49ce79a9c64bd87e51861fdbcc082 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 Aug 2024 20:18:52 +0000 Subject: [PATCH 27/29] further fix imports --- vllm/model_executor/layers/fused_moe/__init__.py | 13 ++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 8 +++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 0f8060066e26..e8efc82d621b 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,17 +1,20 @@ +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.triton_utils import HAS_TRITON +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", +] + if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) from vllm.model_executor.layers.fused_moe.fused_moe_awq import ( fused_experts_awq) - from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) - __all__ = [ - "FusedMoE", - "FusedMoEMethodBase", + __all__ += [ "fused_experts_awq", "fused_moe", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 976a8f5e8ad7..da3db6be1318 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,9 +8,6 @@ tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, 
- fused_topk, - grouped_topk) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -71,6 +68,8 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, def forward_cuda(self, layer: torch.nn.Module, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, **kwargs) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -261,6 +260,9 @@ def weight_loader(self, param: torch.nn.Parameter, def _select_experts(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, grouped_topk) + # DeekSeekv2 uses grouped_top_k if self.use_grouped_topk: assert self.topk_group is not None From be34dc066b11adf1e9771f64c36d84f88b112065 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 Aug 2024 20:32:53 +0000 Subject: [PATCH 28/29] fix --- vllm/model_executor/layers/quantization/awq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index c2c210092417..3ad391ed2403 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,8 +4,7 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_experts_awq) +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -287,7 +286,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def apply(self, layer: torch.nn.Module, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, **kwargs) -> torch.Tensor: - + from vllm.model_executor.layers.fused_moe.fused_moe_awq import ( + fused_experts_awq) return fused_experts_awq(hidden_states=x, w1=layer.w13_qweight, w2=layer.w2_qweight, From 6e7bbf975a09fe5f82176f71bce7b5cf0f5e1a2e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 Aug 2024 20:43:04 +0000 Subject: [PATCH 29/29] fix fp8 --- vllm/model_executor/layers/quantization/fp8.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 939960ba395a..a32e421932c0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,8 +6,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_experts) +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -410,7 +409,7 @@ def process_weights_after_loading(self, layer: Module) -> None: def apply(self, layer: torch.nn.Module, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, **kwargs) -> torch.Tensor: - + from vllm.model_executor.layers.fused_moe import fused_experts return fused_experts(x, 
layer.w13_weight, layer.w2_weight,
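
---

Note (illustrative, not part of the patch series): the later patches above converge on two patterns — the `FusedMoEMethodBase.apply()` signature takes `**kwargs` so backend-specific router arguments (`use_grouped_topk`, `topk`, `gating_output`, `renormalize`) can be threaded from `FusedMoE.forward()` to whichever backend (`forward_cuda`, `forward_tpu`) handles the call, and Triton-dependent modules are imported lazily inside method bodies so that importing the `fused_moe` package does not pull in Triton. The sketch below is a minimal stand-in for those two patterns only; the `Toy*` names are hypothetical, the tensor math is a placeholder, and it is not the vLLM implementation.

```python
# Minimal sketch of the **kwargs pass-through and lazy-import patterns.
# ToyMoEMethodBase / ToyUnquantizedMoEMethod are stand-in names, not vLLM APIs.
from abc import ABC, abstractmethod

import torch


class ToyMoEMethodBase(ABC):

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              topk_weights: torch.Tensor, topk_ids: torch.Tensor,
              **kwargs) -> torch.Tensor:
        raise NotImplementedError


class ToyUnquantizedMoEMethod(ToyMoEMethodBase):

    def apply(self, layer, x, topk_weights, topk_ids, **kwargs):
        # Extra router arguments arrive via **kwargs; each backend picks out
        # only what it needs and ignores the rest.
        return self.forward_cuda(layer, x, topk_weights, topk_ids, **kwargs)

    def forward_cuda(self, layer, x, topk_weights, topk_ids, **kwargs):
        # Lazy import, mirroring how the real code defers its Triton-backed
        # fused_experts import until the CUDA path actually runs.
        import torch.nn.functional as F

        assert not kwargs.get("use_grouped_topk", False)
        if kwargs.get("renormalize", False):
            # L1-renormalize the per-token expert weights.
            topk_weights = F.normalize(topk_weights, p=1.0, dim=-1)
        # Placeholder compute: weight the activations by the router weights
        # and reduce over the selected experts.
        return (x.unsqueeze(1) * topk_weights.unsqueeze(-1)).sum(dim=1)


if __name__ == "__main__":
    method = ToyUnquantizedMoEMethod()
    x = torch.randn(4, 8)
    topk_weights = torch.softmax(torch.randn(4, 2), dim=-1)
    topk_ids = torch.randint(0, 8, (4, 2))
    out = method.apply(layer=torch.nn.Module(), x=x,
                       topk_weights=topk_weights, topk_ids=topk_ids,
                       use_grouped_topk=False, renormalize=True)
    print(out.shape)  # torch.Size([4, 8])
```

The same structure explains why `forward_tpu` in the patches can accept `topk`, `gating_output`, and `renormalize` while `forward_cuda` ignores them: the base `apply(..., **kwargs)` signature lets `FusedMoE.forward()` pass the full set of arguments once, and each backend consumes its own subset.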