diff --git a/csrc/ops.h b/csrc/ops.h index d5d6e240da7c..30529a5469e6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -99,13 +99,25 @@ torch::Tensor awq_dequantize( int thx, int thy); +torch::Tensor awq_group_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, + int split_k_iters); + torch::Tensor marlin_gemm( - torch::Tensor& a, + torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, + torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, + int64_t size_m, + int64_t size_n, int64_t size_k); #endif @@ -129,6 +141,29 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); +torch::Tensor group_gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + torch::Tensor topk_weights, + torch::Tensor sorted_token_ids_ptr, + torch::Tensor expert_ids_ptr, + torch::Tensor num_tokens_post_padded, + bool mul_weights, + bool use_exllama +); + +torch::Tensor dequant_gptq( + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + int bits, + bool use_exllama +); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5c6439fd690..79fdd08fddea 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -64,12 +64,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops #ifndef USE_ROCM ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("awq_group_gemm", &awq_group_gemm, "Grouped Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("group_gptq_gemm", &group_gptq_gemm, "Grouped Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); + ops.def("dequant_gptq", &dequant_gptq, "Dequantize gptq weight to half"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def( "moe_align_block_size", diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16ae..f721af6b20b0 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -276,9 +276,17 @@ __global__ void __launch_bounds__(64) dequantize_weights( half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, - int G + int G, + int in_c, + int out_c ) { + if (blockIdx.z > 0) { + B = B + blockIdx.z * in_c * out_c / 8; + scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; + zeros = zeros + blockIdx.z * in_c * out_c / G / 8; + C = C + blockIdx.z * in_c * out_c; + } int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -326,6 +334,251 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } +template +__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const 
int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + const int expert_num, + int pad_M, + int M, + int IC, + int OC, + half* __restrict__ C) +{ + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + int num_tokens = *num_tokens_post_padded; + int j_factors1 = ((OC + N - 1) / N); + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1); + int block = blockIdx_y / j_factors1; + if (block * 16 >= num_tokens) return; + + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (N + 8)]; + + __shared__ half scaling_factors_shared[N]; + __shared__ half zeros_shared[N]; + + half A_shared_warp[8]; + half B_shared_warp[N / 4]; + for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / N; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + + int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); + int token_id = sorted_token_ids_ptr[row]; + bool ld_A_flag = (token_id < num_valid_tokens); + half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8; + + int expert_id = expert_ids_ptr[block]; + B = B + OC * IC / 8 * expert_id; + scaling_factors = scaling_factors + OC * IC / G * expert_id; + zeros = zeros + OC * IC / G / 8 * expert_id; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. 
and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : 
"r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#else + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), 
"f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + +#endif + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + int token_id = sorted_token_ids_ptr[row_offset]; + if (token_id < num_valid_tokens) + { + float value = C_warp[(ax1_0_1 * 8) + local_id]; + if (topk_weights) { + value = value * topk_weights[token_id]; + } + *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value); + } + } + } +#endif +} + } // namespace awq } // namespace vllm @@ -337,10 +590,11 @@ torch::Tensor awq_dequantize( int thx, int thy) { - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); + int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); + int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); + int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); + int G = in_c / (_kernel.dim() == 2 ? 
_scaling_factors.size(0) : _scaling_factors.size(1)); int x_thread = thx; int y_thread = thy; @@ -363,19 +617,24 @@ torch::Tensor awq_dequantize( const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + at::Tensor _de_kernel; + if (num_experts == 1) { + _de_kernel = torch::empty({in_c, out_c}, options); + } else { + _de_kernel = torch::empty({num_experts, in_c, out_c}, options); + } auto kernel = reinterpret_cast(_kernel.data_ptr()); auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); + dim3 num_blocks(x_blocks, y_blocks, num_experts); dim3 threads_per_block(x_thread, y_thread); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); return _de_kernel; } @@ -444,3 +703,69 @@ torch::Tensor awq_gemm( } return _out_feats.sum(0); } + +torch::Tensor awq_group_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int pad_num_in_feats = _sorted_token_ids_ptr.size(0); + int num_in_channels = _in_feats.size(2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + int num_experts = _topk_weights.size(1); + int top_k = num_experts / _in_feats.size(1); + int group_size = num_in_channels / _scaling_factors.size(1); + + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto topk_weights = mul_weights ? 
reinterpret_cast(_topk_weights.data_ptr()) : nullptr; + auto sorted_token_ids_ptr = reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); + auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); + auto num_tokens_post_padded = reinterpret_cast(_num_tokens_post_padded.data_ptr()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); +} diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f55..8970bdd3cd13 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -921,6 +921,13 @@ __global__ void reconstruct_exllama_4bit_kernel half* __restrict__ b ) { + if (blockIdx.z > 0){ + b_q_weight = b_q_weight + blockIdx.z * size_k * size_n / 8; + b_gptq_scales = b_gptq_scales + blockIdx.z * groups * size_n; + b_gptq_qzeros = b_gptq_qzeros + blockIdx.z * groups * size_n / 8; + if (b_q_perm) b_q_perm = b_q_perm + blockIdx.z * size_k; + b = b + blockIdx.z * size_k * size_n; + } MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -1235,6 +1242,7 @@ void reconstruct_exllama int height, int width, int groups, + int num_experts, int bit ) { @@ -1243,6 +1251,7 @@ void reconstruct_exllama blockDim.y = 1; gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + gridDim.z = num_experts; auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; if (bit == 2) { @@ -1510,6 +1519,13 @@ __global__ void reconstruct_gptq_kernel half* __restrict__ out ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * height * width / 8; + w_scales = w_scales + blockIdx.z * group * width; + w_zeros = w_zeros + blockIdx.z * group * width / 8; + g_idx = g_idx + blockIdx.z * height; + out = out + blockIdx.z * height * width; + } // Start of block int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; @@ -1597,6 +1613,7 @@ void reconstruct_gptq int height, int width, int groups, + int num_experts, int bit ) { @@ -1605,6 +1622,7 @@ void reconstruct_gptq blockDim.y = 1; gridDim.y = DIVIDE(height, 32 / bit); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + gridDim.z = num_experts; auto kernel = reconstruct_gptq_kernel; if (bit == 2) { @@ -1631,6 +1649,33 @@ void 
reconstruct_gptq } +void dequant_gptq_cuda +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* temp_dq, + int size_k, + int size_n, + int groups, + int num_experts, + int bits, + bool use_exllama +) +{ + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups, num_experts, bits); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, num_experts, bits); + } +} + + void gemm_half_q_half_cuda ( cublasHandle_t cublas_handle, @@ -1658,15 +1703,8 @@ void gemm_half_q_half_cuda } if (use_reconstruct) { // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); - } + dequant_gptq_cuda(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups, 1, bit, use_exllama); const half alpha = __float2half(1.0f); const half beta = __float2half(0.0f); @@ -1768,9 +1806,16 @@ __global__ void make_sequential_4bit_kernel const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, + const int w_height, const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 8; + } + const uint64_t* w2 = (uint64_t*) w; uint64_t* w_new2 = (uint64_t*) w_new; int w2_stride = w_width >> 1; @@ -1804,9 +1849,15 @@ __global__ void make_sequential_2bit_kernel const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, + const int w_height, const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 16; + } const uint64_t* w2 = (uint64_t*) w; uint64_t* w_new2 = (uint64_t*) w_new; int w2_stride = w_width >> 1; @@ -1840,9 +1891,15 @@ __global__ void make_sequential_3bit_kernel const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, + const int w_height, const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 32 / 3; + } int w_column = THREADS_X * blockIdx.x + threadIdx.x; if (w_column >= w_width) return; int w_new_row = blockIdx.y * 3; @@ -1926,9 +1983,15 @@ __global__ void make_sequential_8bit_kernel const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, + const int w_height, const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 4; + } const uint64_t* w2 = (uint64_t*) w; uint64_t* w_new2 = (uint64_t*) w_new; int w2_stride = w_width >> 1; @@ -1964,19 +2027,21 @@ void shuffle_exllama_weight int* q_perm, int height, int width, + int num_experts, int bit ) { if (q_perm) { uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); + cudaMalloc(&new_qweight, num_experts * height / 32 * bit * width * sizeof(uint32_t)); dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = 
DIVIDE(width, THREADS_X); gridDim.y = height / 32 * bit; + gridDim.z = num_experts; auto kernel = make_sequential_4bit_kernel; if (bit == 2) { @@ -1993,10 +2058,11 @@ void shuffle_exllama_weight q_weight, new_qweight, q_perm, + height / 32 * bit, width ); // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + cudaMemcpyAsync(q_weight, new_qweight, num_experts * height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); // Cleanup cudaDeviceSynchronize(); cudaFree(new_qweight); @@ -2015,7 +2081,438 @@ void shuffle_exllama_weight shuffle_kernel = shuffle_8bit_kernel; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(q_weight, height, width); + shuffle_kernel<<>>(q_weight, height * num_experts, width); +} + + +template +__global__ void group_gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k +) +{ + int num_tokens = *num_tokens_post_padded; + int offset_m = blockIdx.y * m_count; + if (offset_m >= num_tokens) return; + + int expert_id = expert_ids_ptr[blockIdx.y]; + b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id; + b_gptq_qzeros = b_gptq_qzeros + groups * size_n / 8 * expert_id; + b_gptq_scales = b_gptq_scales + groups * size_n * expert_id; + if (b_q_perm) b_q_perm = b_q_perm + size_k * expert_id; + + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + int token_a[m_count]; + + int valid_count = m_count; + for (int m = 0; m < m_count; ++m) { + int token_id = sorted_token_ids_ptr[offset_m + m]; + if (token_id >= num_valid_tokens) { + valid_count = m; + break; + } + token_a[m] = token_id; + } + + if (offset_k + t < end_k) + { + for (int m = 0; m < valid_count; ++m) + { + const half* a_ptr = a_.item_ptr(token_a[m] / top_k, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + 
b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + for (int m = 0; m < valid_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < valid_count; m++) + { + if (topk_weights) { + #pragma unroll + for (int j = 0; j < 4; ++j) { + block_c[m][j] = block_c[m][j] * topk_weights[token_a[m]]; + } + } + half2 *out = (half2*) c_.item_ptr(token_a[m], n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +void group_gemm_half_q_half +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* c, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_gemm_half_q_half_gptq_kernel<<>> + ( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + c, + size_m, + size_n, + size_k, + groups, + b_q_perm, + topk_weights, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded, + num_valid_tokens, + top_k + ); +} + +__global__ void group_gemm_half_q_half_alt_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ 
g_idx, + int batch, + int height, + int width, + int groups, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k +) +{ + int num_tokens = *num_tokens_post_padded; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + if (b >= num_tokens) return; + + int expert_id = expert_ids_ptr[blockIdx.y]; + mat = mat + height * width * expert_id; + scales = scales + groups * width * expert_id; + zeros = zeros + groups * width / 8 * expert_id; + g_idx = g_idx + height * 8 * expert_id; + + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b_end = BLOCK_M_SIZE_MAX; + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + int token_a[BLOCK_M_SIZE_MAX]; + for (int m = 0; m < b_end; ++m) { + int token_id = sorted_token_ids_ptr[b + m]; + if (token_id >= num_valid_tokens) { + b_end = m; + break; + } + token_a[m] = token_id; + } + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[token_a[m] / top_k * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + if (topk_weights) { + res[m] = __float2half(__half2float(res[m]) * topk_weights[token_a[m]]); + } + atomicAdd(&mul[token_a[m] * width + w], res[m]); + } +} + + +void group_gemm_half_q_half_alt +( + const half* a, + const uint32_t* b_q_weight, + 
const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_gemm_half_q_half_alt_kernel<<>> + ( + (const half2*) a, + b_q_weight, + c, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + size_m, + size_k / 8, + size_n, + groups, + topk_weights, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded, + num_valid_tokens, + top_k + ); +} + +// Only support 4-bit so far +void group_gemm_half_q_half_cuda +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups, + bool use_exllama +) { + if (use_exllama) { + group_gemm_half_q_half( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, + top_k, size_m, size_n, size_k, pad_size_m, groups + ); + } else { + group_gemm_half_q_half_alt( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, + top_k, size_m, size_n, size_k, pad_size_m, groups + ); + } } } // namespace gptq @@ -2065,11 +2562,107 @@ void gptq_shuffle ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + + int num_experts = q_weight.dim() == 3 ? q_weight.size(0) : 1; + int size_k = q_weight.dim() == 3 ? q_weight.size(1) * 32 / bit : q_weight.size(0) * 32 / bit; + int size_n = q_weight.dim() == 3 ? q_weight.size(2) : q_weight.size(1); + vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 32 / bit, - q_weight.size(1), + size_k, + size_n, + num_experts, bit ); } + +// Only support 4-bit +// todo: extend support to other bits +torch::Tensor group_gptq_gemm +( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + torch::Tensor topk_weights, + torch::Tensor sorted_token_ids_ptr, + torch::Tensor expert_ids_ptr, + torch::Tensor num_tokens_post_padded, + bool mul_weights, + bool use_exllama +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::zeros({a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options); + + vllm::gptq::group_gemm_half_q_half_cuda + ( + (const half*) a.data_ptr(), + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? 
NULL : (const int*) b_g_idx.data_ptr(), + (half*) c.data_ptr(), + mul_weights ? (const float*) topk_weights.data_ptr() : NULL, + (const int*) sorted_token_ids_ptr.data_ptr(), + (const int*) expert_ids_ptr.data_ptr(), + (const int*) num_tokens_post_padded.data_ptr(), + topk_weights.numel(), // num tokens + topk_weights.size(1) / a.size(1), // top_k + a.size(0) * a.size(1), // m + c.size(2), // n + a.size(2), // k + sorted_token_ids_ptr.size(0), + b_gptq_qzeros.size(1), // group number + use_exllama + ); + return c; +} + + +torch::Tensor dequant_gptq +( + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + int bits, + bool use_exllama +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales)); + auto options = torch::TensorOptions().dtype(b_gptq_scales.dtype()).device(b_gptq_scales.device()); + + at::Tensor temp_dq; + int num_experts; + int size_k; + int size_n; + int groups; + // moe + if (b_q_weight.dim() == 3) { + temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 32 / bits, b_q_weight.size(2)}, options); + num_experts = b_q_weight.size(0); + size_k = b_q_weight.size(1) * 32 / bits; + size_n = b_q_weight.size(2); + groups = b_gptq_scales.size(1); + } else + { + temp_dq = torch::empty({b_q_weight.size(0) * 32 / bits, b_q_weight.size(1)}, options); + num_experts = 1; + size_k = b_q_weight.size(0) * 32 / bits; + size_n = b_q_weight.size(1); + groups = b_gptq_scales.size(0); + } + vllm::gptq::dequant_gptq_cuda( + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) temp_dq.data_ptr(), + size_k, size_n, groups, + num_experts, bits, use_exllama); + return temp_dq; +} diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94..a9dfcf9a97fe 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,14 +2,24 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +import tempfile + import pytest import torch from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from vllm._C import ops from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + AWQLinearMethod) +from vllm.model_executor.layers.quantization.gptq import (ExllamaState, + GPTQConfig, + GPTQLinearMethod) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) def torch_moe(a, w1, w2, score, topk): @@ -48,7 +58,13 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + inplace=False) torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) @@ -59,6 +75,17 @@ def test_fused_moe( def test_mixtral_moe(dtype: torch.dtype): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" + # Initialize dist environment + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + initialize_model_parallel() + torch.set_default_dtype(dtype) # Instantiate our and huggingface's MoE blocks config = MixtralConfig() @@ -68,7 +95,6 @@ def test_mixtral_moe(dtype: torch.dtype): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - params_dtype=dtype, tp_size=1, ).cuda() @@ -77,8 +103,8 @@ def test_mixtral_moe(dtype: torch.dtype): for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.ws[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.ws.weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2s.weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") @@ -89,6 +115,9 @@ def test_mixtral_moe(dtype: torch.dtype): hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) + # destroy dist environment + destroy_model_parallel() + mixtral_moe_tol = { torch.float32: 1e-3, torch.float16: 1e-3, @@ -99,3 +128,177 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, w2_gidx, w2_scale, + w2_zero, score, topk): + score = torch.softmax(score.float(), dim=-1) + topk_weight, topk_ids = torch.topk(score, topk) + (B, D) = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[2], + dtype=a.dtype, + device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + dw1 = ops.dequant_gptq(w1[i], w1_zero[i], w1_scale[i], w1_gidx[i], + 4, False) + dw2 = ops.dequant_gptq(w2[i], w2_zero[i], w2_scale[i], w2_gidx[i], + 4, False) 
+ r1 = SiluAndMul()(torch.matmul(a[mask], dw1)) + out[mask] = torch.matmul(r1, dw2) + return (out.view(B, -1, w2.shape[2]) * + topk_weight.view(B, -1, 1)).sum(dim=1).half() + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("exstate", + [ExllamaState.UNINITIALIZED, ExllamaState.UNUSED]) +@pytest.mark.parametrize("groupsize", [-1, 128]) +@pytest.mark.parametrize("actorder", [True, False]) +def test_fused_moe_gptq(m: int, n: int, k: int, e: int, topk: int, + exstate: ExllamaState, groupsize: int, actorder: bool): + RANGE = 1000000000 + a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 + qw1 = torch.randint(-RANGE, + RANGE, (e, (k // 32) * 4, n * 2), + dtype=torch.int, + device='cuda') + qw2 = torch.randint(-RANGE, + RANGE, (e, (n // 32) * 4, k), + dtype=torch.int, + device='cuda') + + groupsize1 = groupsize if groupsize != -1 else k + groupsize2 = groupsize if groupsize != -1 else n + gidx1 = torch.tensor([i // groupsize1 for i in range(k)], + dtype=torch.int32, + device='cuda').unsqueeze(0).expand(e, k).contiguous() + gidx2 = torch.tensor([i // groupsize2 for i in range(n)], + dtype=torch.int32, + device='cuda').unsqueeze(0).expand(e, n).contiguous() + + scale1 = torch.randn( + (e, k // groupsize1, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn( + (e, n // groupsize2, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, + RANGE, (e, k // groupsize1, (n * 2 // 32) * 4), + dtype=torch.int32, + device='cuda') + zero2 = torch.randint(-RANGE, + RANGE, (e, n // groupsize2, (k // 32) * 4), + dtype=torch.int32, + device='cuda') + w1 = { + "qweight": qw1, + "g_idx": gidx1, + "scales": scale1, + "qzeros": zero1, + "exllama_state": exstate + } + w2 = { + "qweight": qw2, + "g_idx": gidx2, + "scales": scale2, + "qzeros": zero2, + "exllama_state": exstate + } + + score = torch.randn((m, e), device='cuda', dtype=torch.half) + + gptq_method = GPTQLinearMethod(GPTQConfig(4, groupsize, actorder)) + torch_output = torch_moe_gptq(a, qw1, gidx1, scale1, zero1, qw2, gidx2, + scale2, zero2, score, topk) + cuda_output = gptq_method.apply_moe_weights(w1, w2, a, score, topk, False) + # gptq kernels have large variance in output + assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) + + +def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score, + topk): + score = torch.softmax(score.float(), dim=-1) + topk_weight, topk_ids = torch.topk(score, topk) + (B, D) = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[2] * 8, + dtype=a.dtype, + device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0) + dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0) + r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1)) + out[mask] = torch.matmul(r1, dw2).to(out.dtype) + return (out.view(B, -1, w2.shape[2] * 8) * + topk_weight.view(B, -1, 1)).sum(dim=1).half() + + +@pytest.mark.parametrize("m", [1024, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2, 6]) +def 
test_fused_moe_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + # awq requires minimum capability 75 + if torch.version.hip is not None: + return + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < 75: + return + + RANGE = 1000000000 + groupsize = 128 + a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 + qw1 = torch.randint(-RANGE, + RANGE, (e, k, n * 2 // 8), + dtype=torch.int, + device='cuda') + qw2 = torch.randint(-RANGE, + RANGE, (e, n, k // 8), + dtype=torch.int, + device='cuda') + + scale1 = torch.randn( + (e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn( + (e, n // groupsize, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, + RANGE, (e, k // groupsize, (n * 2 // 32) * 4), + dtype=torch.int32, + device='cuda') + zero2 = torch.randint(-RANGE, + RANGE, (e, n // groupsize, (k // 32) * 4), + dtype=torch.int32, + device='cuda') + w1 = {"qweight": qw1, "scales": scale1, "qzeros": zero1} + w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2} + + score = torch.randn((m, e), device='cuda', dtype=torch.half) + + awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) + torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, + score, topk) + cuda_output = awq_method.apply_moe_weights(w1, w2, a, score, topk, False) + assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62..c399facdcf23 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_moe, fused_topk, get_config_file_name, moe_align_block_size) __all__ = [ - "fused_moe", - "get_config_file_name", + "fused_moe", "moe_align_block_size", "fused_topk", "get_config_file_name" ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ec09f0cd4c2..a0d103fe5583 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -245,6 +245,53 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ) +def fused_topk( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + """Compute top-k indice and weights from gating logits + + Args: + gating_output (torch.Tensor): The output of the gating operation + (before softmax). + topk (int): The number of top-k experts to select. + renormalize (bool): If True, renormalize the top-k weights to sum to 1. + """ + M = gating_output.shape[0] + if is_hip(): + # The MoE kernels are not yet supported on ROCm. + routing_weights = torch.softmax(gating_output, + dim=-1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + import vllm._moe_C as moe_kernels + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=gating_output.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=gating_output.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=gating_output.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. 
+ ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + def get_config_file_name(E: int, N: int) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") return f"E={E},N={N},device_name={device_name}.json" @@ -286,7 +333,7 @@ def fused_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, + inplace: bool = True, override_config: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: """ @@ -315,44 +362,13 @@ def fused_moe( assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] M, _ = hidden_states.shape E, N, _ = w1.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. - routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - import vllm._moe_C as moe_kernels - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. 
- if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) if override_config: config = override_config diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..11a056257624 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm._C import ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import fused_moe, fused_topk from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -44,6 +46,68 @@ def apply_weights(self, """Apply the weights to the input tensor.""" raise NotImplementedError + def create_moe_weights(self, num_experts: int, + input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Creating moe weights""" + linear_weights = self.create_weights(input_size_per_partition, + output_size_per_partition, + input_size, output_size, + params_dtype) + if num_experts == 1: + return linear_weights + for name, param in tuple(linear_weights.items()): + if isinstance(param, Parameter): + repeat_size = (num_experts, ) + (1, ) * param.dim() + new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size), + requires_grad=False) + set_weight_attrs(new_param, param.__dict__) + linear_weights[name] = new_param + return linear_weights + + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: + """Apply the weights to the input tensor.""" + routing_weights, selected_experts = fused_topk(gating_output, + topk, + renormalize=renormalize) + final_hidden_states = None + num_experts = gating_output.shape[-1] + for expert_idx in range(num_experts): + w1_expert = { + key: + value[expert_idx] if isinstance(value, torch.Tensor) else value + for key, value in w1.items() + } + w2_expert = { + key: + value[expert_idx] if isinstance(value, torch.Tensor) else value + for key, value in w2.items() + } + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + hidden_states = self.apply_weights(w1_expert, x) + output_shape = (hidden_states.shape[:-1] + + (hidden_states.shape[-1] // 2, )) + out = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + ops.silu_and_mul(out, hidden_states) + current_hidden_states = self.apply_weights( + w2_expert, out).mul_(expert_weights) + + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. 
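# Illustrative sketch (not part of this patch) of how a MoE block is expected to use the
# num_experts-aware layers and apply_moe_weights added above; the attribute names
# (ws, w2s, router_logits, top_k) are assumptions for this example, not code from the diff:
#
#     self.ws = MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2,
#                                          bias=False, linear_method=linear_method,
#                                          num_experts=num_experts)
#     self.w2s = RowParallelLinear(intermediate_size, hidden_size, bias=False,
#                                  linear_method=linear_method,
#                                  num_experts=num_experts)
#     ...
#     out = self.ws.linear_method.apply_moe_weights(
#         self.ws.linear_weights, self.w2s.linear_weights,
#         hidden_states, router_logits, self.top_k, renormalize=True)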
@@ -78,6 +142,14 @@ def apply_weights(self, return F.linear(x, weight) return F.linear(x, weight, bias) + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: + return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk, + renormalize) + class ReplicatedLinear(torch.nn.Module): """Replicated linear layer. @@ -161,6 +233,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): super().__init__() @@ -178,9 +251,10 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) + self.num_experts = num_experts + self.linear_weights = self.linear_method.create_moe_weights( + num_experts, self.input_size, self.output_size_per_partition, + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -196,10 +270,20 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + expert_id: int = -1): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data + if self.num_experts > 1: + if expert_id >= 0: + param_data = param_data[expert_id] + # Loaded weight is packed at expert dim + else: + output_dim = output_dim + 1 + if output_dim is not None: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size @@ -253,19 +337,28 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + skip_bias_add, params_dtype, linear_method, + num_experts) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + loaded_shard_id: Optional[int] = None, + expert_id: int = -1): param_data = param.data output_dim = getattr(param, "output_dim", None) + if self.num_experts > 1: + if expert_id >= 0: + param_data = param_data[expert_id] + # Loaded weight is packed at expert dim + elif output_dim is not None: + output_dim = output_dim + 1 if loaded_shard_id is None: # Loaded weight is already packed. 
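# --- Illustrative aside (not part of the patch) -----------------------------
# Shape sketch for `create_moe_weights` and the expert-aware `weight_loader`
# added above. Sizes are hypothetical; the point is that per-expert weights
# are stacked along a new leading expert dimension, so `output_dim` (or
# `input_dim`) shifts by one when a loaded tensor is packed at the expert
# dimension (expert_id == -1).
import torch

num_experts, out_features, in_features = 8, 512, 128

# What create_moe_weights builds from a (out_features, in_features) parameter:
stacked = torch.empty(num_experts, out_features, in_features)

# Loading one expert's checkpoint tensor (expert_id >= 0): index first.
one_expert = torch.randn(out_features, in_features)
stacked[3].copy_(one_expert)

# Loading a tensor already packed at the expert dim (expert_id == -1):
# the original output_dim 0 becomes 1 in the stacked parameter.
packed = torch.randn(num_experts, out_features, in_features)
stacked.copy_(packed)
# -----------------------------------------------------------------------------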
if output_dim is None: @@ -506,6 +599,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): super().__init__() # Keep input parameters @@ -524,9 +618,10 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) + self.num_experts = num_experts + self.linear_weights = self.linear_method.create_moe_weights( + num_experts, self.input_size_per_partition, self.output_size, + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -546,10 +641,19 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + expert_id: int = -1): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data + if self.num_experts > 1: + if expert_id >= 0: + param_data = param_data[expert_id] + # Loaded weight is packed at expert dim + elif input_dim is not None: + input_dim = input_dim + 1 if input_dim is not None: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..b608c8645bcb 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,6 +4,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops +from vllm.model_executor.layers.fused_moe import (fused_moe, fused_topk, + moe_align_block_size) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -165,3 +167,45 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) + + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + dequant_w1 = ops.awq_dequantize(w1["qweight"], w1["scales"], + w1["qzeros"], 0, 0, + 0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize(w2["qweight"], w2["scales"], + w2["qzeros"], 0, 0, + 0).permute(0, 2, 1) + return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, + renormalize) + + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size( + topk_ids, 16, w1["qweight"].shape[0]) + + x = x.view(x.shape[0], 1, *x.shape[1:]) + pack_factor = self.quant_config.pack_factor + + gate_up = ops.awq_group_gemm(x, w1["qweight"], w1["scales"], + w1["qzeros"], topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, False, + pack_factor) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=x.dtype, + device=x.device) + ops.silu_and_mul(out, gate_up) + + out = ops.awq_group_gemm(out, w2["qweight"], w2["scales"], + w2["qzeros"], topk_weights, sorted_token_ids, + 
expert_ids, num_tokens_post_padded, True, + pack_factor) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..a987dd34eb26 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -7,6 +7,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops +from vllm.model_executor.layers.fused_moe import (fused_moe, fused_topk, + moe_align_block_size) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -213,3 +215,62 @@ def apply_weights(self, if bias is not None: output = output + bias return output.reshape(out_shape) + + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: + # shuffle weights for exllama + for w in [w1, w2]: + if w["exllama_state"] == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + w["g_idx"] = torch.argsort(w["g_idx"], + dim=-1).to(torch.int) + else: + w["g_idx"] = torch.empty((w["g_idx"].shape[0], 1), + device="meta") + w["exllama_state"] = ExllamaState.READY + ops.gptq_shuffle(w["qweight"], w["g_idx"], + self.quant_config.weight_bits) + + # Fused moe only supports 4-bit + if self.quant_config.weight_bits != 4: + return super().apply_moe_weights(w1, w2, x, gating_output, topk, + renormalize) + + if x.shape[0] >= 128: + dequant_w1 = ops.dequant_gptq( + w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], + self.quant_config.weight_bits, + w1["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) + dequant_w2 = ops.dequant_gptq( + w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], + self.quant_config.weight_bits, + w2["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) + return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, + renormalize) + + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size( + topk_ids, 8, w1["qweight"].shape[0]) + + x = x.view(x.shape[0], 1, *x.shape[1:]) + gate_up = ops.group_gptq_gemm( + x, w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, + False, w1["exllama_state"] == ExllamaState.READY) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=x.dtype, + device=x.device) + ops.silu_and_mul(out, gate_up) + + out = ops.group_gptq_gemm(out, w2["qweight"], w2["qzeros"], + w2["scales"], w2["g_idx"], topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, True, + w2["exllama_state"] == ExllamaState.READY) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 958c9b97f725..ba29a0e1c060 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -27,11 +27,6 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. 
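# --- Illustrative aside (not part of the patch) -----------------------------
# The AWQ and GPTQ `apply_moe_weights` implementations above share one shape
# flow; this torch-only outline restates it with dummy tensors standing in
# for the grouped-GEMM outputs (the real kernels are the custom ops added by
# this patch and are not reproduced here). Sizes are hypothetical.
import torch
import torch.nn.functional as F

num_tokens, topk, hidden_size, intermediate_size = 4, 2, 64, 128

# 1/2. fused_topk + moe_align_block_size give routing weights and the sorted
#      token order consumed by the grouped GEMM kernels.
# 3.   First grouped GEMM: (num_tokens, topk, 2 * intermediate_size).
gate_up = torch.randn(num_tokens, topk, 2 * intermediate_size)
# 4.   silu_and_mul halves the last dim (the patch uses the fused CUDA op;
#      plain torch is shown here only to make the shape change explicit).
d = gate_up.shape[-1] // 2
act = F.silu(gate_up[..., :d]) * gate_up[..., d:]
# 5.   Second grouped GEMM (mul_weights=True) applies the routing weights and
#      returns (num_tokens, topk, hidden_size); summing over dim=1 yields the
#      final hidden states, matching `torch.sum(out, dim=1)` above.
down = torch.randn(num_tokens, topk, hidden_size)
final_hidden_states = down.sum(dim=1)
# -----------------------------------------------------------------------------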
- if (model_config.quantization is not None - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] for arch in architectures: model_cls = ModelRegistry.load_model_cls(arch) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b5c7e44de619..2ffda607b0aa 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -37,7 +37,6 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4eba..cd0d7dc650ae 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -29,13 +29,16 @@ from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm +# yapf conflicts with isort for this block +# yapf: disable from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, - RowParallelLinear) + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -96,20 +99,20 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.n_routed_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok - if self.tp_size > self.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") - - self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) - self.pack_params() + self.linear_method = linear_method + if self.linear_method is None: + self.linear_method = UnquantizedLinearMethod() + + self.w1 = MergedColumnParallelLinear( + config.hidden_size, [config.moe_intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear(config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, @@ -127,25 +130,6 @@ def __init__( reduce_results=False, ) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = 
self.w2.view(len(w2), *w2s[0].shape) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -153,13 +137,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + + final_hidden_states = self.linear_method.apply_moe_weights( + self.w1.linear_weights, + self.w2.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -406,8 +392,20 @@ def load_weights(self, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), + ("mlp.gate_up_proj", "mlp.gate_proj", 0), + ("mlp.gate_up_proj", "mlp.up_proj", 1), + ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0), + ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1), + ] + + expert_params_mapping = [ + # (param_name, weight_name, shard_id, expert_id) + ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2", + f"experts.{expert_id}.{weight_name}", shard_id, expert_id) + for expert_id in range(self.config.n_routed_experts) + for weight_name, shard_id in [("gate_proj", + 0), ("up_proj", + 1), ("down_proj", None)] ] params_dict = dict(self.named_parameters()) @@ -426,23 +424,35 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id, + expert_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + if shard_id is None: + weight_loader(param, + loaded_weight, + expert_id=expert_id) + else: + weight_loader(param, + loaded_weight, + shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
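# --- Illustrative aside (not part of the patch) -----------------------------
# How the new `expert_params_mapping` above resolves checkpoint names; the
# example name is hypothetical. "gate_proj"/"up_proj" map into the merged `w1`
# parameter (shard 0 / shard 1), "down_proj" maps into `w2` with no shard id,
# and the expert index selects the slice along the stacked expert dimension.
name = "model.layers.0.mlp.experts.3.gate_proj.weight"
param_name, weight_name, shard_id, expert_id = ("w1", "experts.3.gate_proj",
                                                0, 3)
assert weight_name in name
resolved = name.replace(weight_name, param_name)
# -> "model.layers.0.mlp.w1.weight", then loaded via
#    weight_loader(param, loaded_weight, shard_id, expert_id=expert_id)
print(resolved)
# -----------------------------------------------------------------------------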
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 429bc8109b9f..60f465f02c6e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,12 +29,16 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm +# yapf conflicts with isort for this block +# yapf: disable from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, - RowParallelLinear) + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -43,9 +47,8 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -66,8 +69,8 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -75,63 +78,40 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype + self.linear_method = linear_method + if self.linear_method is None: + self.linear_method = UnquantizedLinearMethod() self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, - params_dtype=self.params_dtype, linear_method=None) - self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = 
loaded_weight[:, shard] + self.ws = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=num_experts) + self.w2s = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + num_experts=num_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True) + + final_hidden_states = self.linear_method.apply_moe_weights( + self.ws.linear_weights, + self.w2s.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=True, + ) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -240,7 +220,8 @@ def __init__( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + intermediate_size=config.intermediate_size, + linear_method=linear_method) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -407,11 +388,11 @@ def load_weights(self, ] expert_params_mapping = [ - # (param_name, weight_name, expert_id) + # (param_name, weight_name, shard_id, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + f"experts.{expert_id}.{weight_name}", shard_id, expert_id) for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] + for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)] ] params_dict = dict(self.named_parameters()) @@ -436,16 +417,24 @@ def load_weights(self, weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for (param_name, weight_name, shard_id, + expert_id) in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + if shard_id is None: + weight_loader(param, + loaded_weight, + expert_id=expert_id) + else: + weight_loader(param, + loaded_weight, + shard_id, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py deleted file mode 100644 index 75f86bc134ee..000000000000 --- a/vllm/model_executor/models/mixtral_quant.py +++ /dev/null @@ -1,413 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Mixtral model.""" -from typing import List, Optional - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from transformers import MixtralConfig - -from vllm.attention import Attention, AttentionMetadata -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - - -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - linear_method=linear_method) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.num_total_experts), self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts 
assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - linear_method=linear_method) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - linear_method=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class MixtralDecoderLayer(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - sliding_window=config.sliding_window, - linear_method=linear_method) - self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class MixtralModel(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class MixtralForCausalLM(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6b4a74198fd5..d9e06bf6ce1a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -31,13 +31,16 @@ from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm +# yapf conflicts with isort for this block +# yapf: disable from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, - RowParallelLinear) + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -103,15 +106,9 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}.") - self.experts = nn.ModuleList([ - Qwen2MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) - self.pack_params() + self.linear_method = linear_method + if self.linear_method is None: + self.linear_method = UnquantizedLinearMethod() self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, @@ -131,24 +128,16 @@ def __init__( 1, bias=False) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) + self.w1 = MergedColumnParallelLinear( + config.hidden_size, [config.moe_intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear(config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -162,13 +151,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + + final_hidden_states = self.linear_method.apply_moe_weights( + self.w1.linear_weights, + self.w2.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output @@ -415,8 +406,20 @@ def load_weights(self, ("qkv_proj", "q_proj", "q"), 
("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), + ("mlp.gate_up_proj", "mlp.gate_proj", 0), + ("mlp.gate_up_proj", "mlp.up_proj", 1), + ("shared_expert.gate_up_proj", "shared_expert.gate_proj", 0), + ("shared_expert.gate_up_proj", "shared_expert.up_proj", 1), + ] + + expert_params_mapping = [ + # (param_name, weight_name, shard_id, expert_id) + ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2", + f"experts.{expert_id}.{weight_name}", shard_id, expert_id) + for expert_id in range(self.config.num_experts) + for weight_name, shard_id in [("gate_proj", + 0), ("up_proj", + 1), ("down_proj", None)] ] params_dict = dict(self.named_parameters()) @@ -435,23 +438,35 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_expert." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_expert." in name) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id, + expert_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + if shard_id is None: + weight_loader(param, + loaded_weight, + expert_id=expert_id) + else: + weight_loader(param, + loaded_weight, + shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)