diff --git a/awq_ext/pybind_awq.cpp b/awq_ext/pybind_awq.cpp
index 4b74994..57cdb34 100644
--- a/awq_ext/pybind_awq.cpp
+++ b/awq_ext/pybind_awq.cpp
@@ -4,13 +4,20 @@
 #include "quantization/gemm_cuda.h"
 #include "quantization/gemv_cuda.h"
 #include "position_embedding/pos_encoding.h"
+#include "vllm/moe_alig_block.h"
+#include "vllm/activation.h"
+#include "vllm/topk_softmax_kernels.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
+    m.def("grouped_gemm_forward", &grouped_gemm_forward, "Quantized grouped GEMM kernel.");
     m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
     m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
     m.def("dequantize_weights_cuda", &dequantize_weights_cuda, "Dequantize weights.");
+    m.def("moe_alig_block_size", &moe_alig_block_size, "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
+    m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
+    m.def("topk_softmax", &topk_softmax, "Computes fused topk and softmax operation.");
 }
\ No newline at end of file
diff --git a/awq_ext/quantization/gemm_cuda.h b/awq_ext/quantization/gemm_cuda.h
index 63be8bf..afc8165 100644
--- a/awq_ext/quantization/gemm_cuda.h
+++ b/awq_ext/quantization/gemm_cuda.h
@@ -3,6 +3,18 @@
 torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
     torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
 
+torch::Tensor grouped_gemm_forward(
+    torch::Tensor _in_feats,
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    torch::Tensor _topk_weights,
+    torch::Tensor _sorted_token_ids_ptr,
+    torch::Tensor _expert_ids_ptr,
+    torch::Tensor _num_tokens_post_padded,
+    bool mul_weights,
+    int split_k_iters);
+
 torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
     torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
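// Sketch (my reading of the kernel in gemm_cuda_gen.cu below, not part of the diff):
// the shape contract implied by grouped_gemm_forward, with IC/OC the per-expert input
// and output channels and G the quantization group size:
//   _kernel          : int32 [num_experts, IC, OC / 8]   (AWQ 4-bit packed, 8 weights per int)
//   _scaling_factors : fp16  [num_experts, IC / G, OC]
//   _zeros           : int32 [num_experts, IC / G, OC / 8]
//   _sorted_token_ids_ptr, _expert_ids_ptr, _num_tokens_post_padded
//                    : the outputs of moe_alig_block_size below
//   _topk_weights    : router weights from topk_softmax, applied when mul_weights is true
// The kernel writes split_k_iters partial results, which the host sums over dim 0.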
diff --git a/awq_ext/quantization/gemm_cuda_gen.cu b/awq_ext/quantization/gemm_cuda_gen.cu
index 558b5c6..c1d9c4e 100644
--- a/awq_ext/quantization/gemm_cuda_gen.cu
+++ b/awq_ext/quantization/gemm_cuda_gen.cu
@@ -733,8 +733,15 @@ __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, //
     int* __restrict__ zeros,  // 32x64   32 rows   64 cols
     half* __restrict__ C,     // 4096x512 4096 rows 512 cols
     int G,
-    bool dbg)
+    int in_c,
+    int out_c)
 {
+  if (blockIdx.z > 0) {
+    B = B + blockIdx.z * in_c * out_c / 8;
+    scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
+    zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
+    C = C + blockIdx.z * in_c * out_c;
+  }
   int j_factors1 = 4;
   int row_stride2 = 4;
   int split_k_iters = 1;
@@ -755,25 +762,11 @@ int index2 = col + row * N;
 int* B_ptr2 = B + index2;
 
-  if (dbg) {
-    printf("\n-------- x %d - y %d --------\n", col, row);
-    printf("- %d-%d - N %d index1 %d \n", col, row, N, index2);
-    printf("- %d-%d - B %d \n", col, row, *B_ptr2);
-  }
-
 int index3 = col + (int)(row / G) * N;
 int* zeros_ptr2 = zeros + index3;
 int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8)
 half* scaling_factors_ptr2 = scaling_factors + index4;
-
-  if (dbg) {
-    printf("- %d-%d - zeros[%d] %d \n", col, row, index3, *zeros_ptr2);
-    printf("- %d-%d - N %d index4 %d \n", col, row, N, index4);
-    for (int i=0; i<8; ++i) {
-      printf("- %d-%d - scale[%d] %f \n", col, row, index4 + i, __half2float(*(scaling_factors_ptr2+i)));
-    }
-  }
-
 uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
 uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
@@ -793,15 +786,321 @@ int j=0;
 *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
 
-  if (dbg) {
-    for (int i=0; i<8; ++i) {
-      printf("- %d-%d - B_shared_ptr2[%d] %f \n", col, row, i, __half2float(*(B_shared_ptr2 + i)) );
+  for (int i=0; i<8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
+template <int N>
+__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
+    int G,
+    int split_k_iters,
+    half* __restrict__ A,
+    int* __restrict__ B,
+    half* __restrict__ scaling_factors,
+    int* __restrict__ zeros,
+    const float* __restrict__ topk_weights,
+    const int* __restrict__ sorted_token_ids_ptr,
+    const int* __restrict__ expert_ids_ptr,
+    const int* __restrict__ num_tokens_post_padded,
+    const int num_valid_tokens,
+    const int top_k,
+    const int expert_num,
+    int pad_M,
+    int M,
+    int IC,
+    int OC,
+    half* __restrict__ C)
+{
+  // Only support matrix n = 64 or 128
+  assert(N == 64 || N == 128);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+  assert(false);
+#else
+  int num_tokens = *num_tokens_post_padded;
+  int j_factors1 = ((OC + N - 1) / N);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
+  int block = blockIdx_y / j_factors1;
+  if (block * 16 >= num_tokens) return;
+
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (N + 8)];
+
+  __shared__ half scaling_factors_shared[N];
+  __shared__ half zeros_shared[N];
+
+  half A_shared_warp[8];
+  half B_shared_warp[N / 4];
+  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
     }
   }
 
-  for (int i=0; i<8; ++i) {
-    *(C_ptr2 + i) = B_shared[i];
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / N;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+
+  int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
+  int token_id = sorted_token_ids_ptr[row];
+  bool ld_A_flag = (token_id < num_valid_tokens);
+  half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int expert_id = expert_ids_ptr[block];
+  B = B + OC * IC / 8 * expert_id;
+  scaling_factors = scaling_factors + OC * IC / G * expert_id;
+  zeros = zeros + OC * IC / G / 8 * expert_id;
+
+  int* B_ptr = B
+      + ((int)threadIdx.y) * (OC / 8) * (256 / N)
+      + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
+      + (((int)blockIdx_y) % j_factors1) * (N / 8)
+      + (((int)threadIdx.x) % (N / 8)) * 1;
+  // Why * 1 in the above line?
+
+  half* A_shared_ptr = A_shared
+      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+      + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+      + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+  half* B_shared_ptr = B_shared
+      + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
+      + (((int)threadIdx.x) / (N / 8)) * (N + 8)
+      + (((int)threadIdx.x) % (N / 8)) * 8;
+
+  int* zeros_ptr = zeros
+      + (((int)blockIdx_y) % j_factors1) * (N / 8)
+      + ((int)threadIdx.x) % (N / 8);
+
+  half* scaling_factors_ptr = scaling_factors
+      + (((int)blockIdx_y) % j_factors1) * N
+      + (((int)threadIdx.x) % (N / 8)) * 8;
+
+  half* C_ptr = C
+      + static_cast<long long>(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim
+      + (((int)blockIdx_y) % j_factors1) * N
+      + ((int)threadIdx.y) * (N / 2)
+      + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
+
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+
+      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned 
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+
+
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+
+      for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+#else
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+#endif
+      }
+    }
  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      int token_id = sorted_token_ids_ptr[row_offset];
+      if (token_id < num_valid_tokens)
+      {
+        float value = C_warp[(ax1_0_1 * 8) + local_id];
+        if (topk_weights) {
+          value = value * topk_weights[token_id];
+        }
+        *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value);
+      }
+    }
+  }
+#endif
+}
+
+
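// Reference semantics of the epilogue above (a sketch of my reading, not part of
// the diff): ignoring tiling and split-k, each padded row produced by
// moe_alig_block_size contributes at most one output row, roughly:
//
//   for (int row = 0; row < *num_tokens_post_padded; ++row) {
//       int token = sorted_token_ids_ptr[row];       // padding slots hold ids >= num_valid_tokens
//       if (token >= num_valid_tokens) continue;
//       int e = expert_ids_ptr[row / 16];            // one expert per 16-row block
//       // C[token] += A[token / top_k] * dequant(B[e], scales[e], zeros[e]),
//       // optionally scaled by topk_weights[token] when mul_weights is set.
//   }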
+torch::Tensor grouped_gemm_forward(
+    torch::Tensor _in_feats,
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    torch::Tensor _topk_weights,
+    torch::Tensor _sorted_token_ids_ptr,
+    torch::Tensor _expert_ids_ptr,
+    torch::Tensor _num_tokens_post_padded,
+    bool mul_weights,
+    int split_k_iters)
+{
+    int num_in_feats = _in_feats.size(0);
+    int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
+    int num_in_channels = _in_feats.size(2);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
+
+    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
+    int num_experts = _topk_weights.size(1);
+    int top_k = num_experts / _in_feats.size(1);
+    int group_size = num_in_channels / _scaling_factors.size(1);
+
+    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options);
+    int num_out_channels = _out_feats.size(-1);
+
+    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+    auto topk_weights = mul_weights ? reinterpret_cast<float*>(_topk_weights.data_ptr<float>()) : nullptr;
+    auto sorted_token_ids_ptr = reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr<int>());
+    auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr<int>());
+    auto num_tokens_post_padded = reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr<int>());
+
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    if (num_out_channels % 128 == 0)
+    {
+        int j_factors1 = num_out_channels / 128 / 1;
+        dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        group_gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
+            _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    else if (num_out_channels % 64 == 0)
+    {
+        int j_factors1 = num_out_channels / 64 / 1;
+        dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        group_gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
+            _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    return _out_feats.sum(0);
 }
 
 // Dequantization to fp16
@@ -815,10 +1114,11 @@ torch::Tensor dequantize_weights_cuda(
     int thy,
     bool dbg)
 {
-    int in_c = _kernel.size(0);
-    int qout_c = _kernel.size(1);
+    int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1);
+    int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2);
+    int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0);
     int out_c = qout_c * 8;
-    int G = in_c / _scaling_factors.size(0);
+    int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0) : _scaling_factors.size(1));
 
     int x_thread = thx;
     int y_thread = thy;
@@ -844,17 +1144,22 @@ torch::Tensor dequantize_weights_cuda(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
     auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
 
-    at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); // row, col 4096x512
+    at::Tensor _de_kernel;
+    if (num_experts == 1) {
+        _de_kernel = torch::empty({in_c, out_c}, options);
+    } else {
+        _de_kernel = torch::empty({num_experts, in_c, out_c}, options);
+    }
 
     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
     auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
 
-    dim3 num_blocks(x_blocks, y_blocks);
+    dim3 num_blocks(x_blocks, y_blocks, num_experts);
     dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096
 
-    dequantize_weights<<<num_blocks, threads_per_block>>>(kernel, scaling_factors, zeros, de_kernel, G, dbg);
+    dequantize_weights<<<num_blocks, threads_per_block>>>(kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);
 
     return _de_kernel;
 }
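// End-to-end sketch (assumptions about tensor names, shapes and split_k_iters, not
// part of the diff): how the new bindings are meant to compose for one MoE layer,
// mirroring the fused MoE flow the vllm/ helpers below come from:
//
//   topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output);
//   moe_alig_block_size(topk_ids, num_experts, /*block_size=*/16,
//                       sorted_token_ids, expert_ids, num_tokens_post_padded);
//   auto gate_up = grouped_gemm_forward(x, w_gate_up, scales1, zeros1, topk_weights,
//                                       sorted_token_ids, expert_ids, num_tokens_post_padded,
//                                       /*mul_weights=*/false, /*split_k_iters=*/8);
//   silu_and_mul(hidden, gate_up);   // hidden has half the last dimension of gate_up
//   auto y = grouped_gemm_forward(hidden, w_down, scales2, zeros2, topk_weights,
//                                 sorted_token_ids, expert_ids, num_tokens_post_padded,
//                                 /*mul_weights=*/true, /*split_k_iters=*/8);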
diff --git a/awq_ext/vllm/activation.cu b/awq_ext/vllm/activation.cu
new file mode 100644
index 0000000..9976767
--- /dev/null
+++ b/awq_ext/vllm/activation.cu
@@ -0,0 +1,56 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#define VLLM_LDG(arg) *(arg)
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)   \
+  AT_DISPATCH_SWITCH(                                   \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+template<typename T>
+__device__ __forceinline__ T silu(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename scalar_t>
+__global__ void silu_and_mul_kernel(
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., 2, d]
+  const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+    out[token_idx * d + idx] = silu(x) * y;
+  }
+}
+
+
+void silu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "silu_and_mul_kernel",
+    [&] {
+      silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        d);
+    });
+}
\ No newline at end of file
diff --git a/awq_ext/vllm/activation.h b/awq_ext/vllm/activation.h
new file mode 100644
index 0000000..41ab653
--- /dev/null
+++ b/awq_ext/vllm/activation.h
@@ -0,0 +1,3 @@
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
\ No newline at end of file
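// Reference semantics of silu_and_mul above (a sketch, not part of the diff): for an
// input of shape [..., 2 * d], out[..., i] = silu(in[..., i]) * in[..., d + i], i.e.
// the SwiGLU gate. A scalar equivalent:
//
//   static inline float silu_and_mul_ref(float x, float y) {
//       return (x / (1.0f + expf(-x))) * y;   // silu(x) * y
//   }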
diff --git a/awq_ext/vllm/moe_alig_block.cu b/awq_ext/vllm/moe_alig_block.cu
new file mode 100644
index 0000000..63578e5
--- /dev/null
+++ b/awq_ext/vllm/moe_alig_block.cu
@@ -0,0 +1,91 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+
+const static size_t NUM_MAX_EXPERTS = 64;
+
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)  \
+  AT_DISPATCH_SWITCH(                                  \
+    TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+template <typename scalar_t>
+__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
+                                           int32_t *sorted_token_ids,
+                                           int32_t *expert_ids,
+                                           int32_t *total_tokens_post_pad,
+                                           int32_t num_experts,
+                                           int32_t block_size,
+                                           size_t numel) {
+    const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x);
+    const size_t start_idx = threadIdx.x * tokens_per_thread;
+    __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
+    __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
+    for(int i = 0;i < num_experts;i++){
+        tokens_cnts[threadIdx.x + 1][i] = 0;
+    }
+
+    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+        ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
+    }
+
+    __syncthreads();
+
+    tokens_cnts[0][threadIdx.x] = 0;
+    for(int i=1;i<=blockDim.x;++i){
+        tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
+    }
+
+    __syncthreads();
+
+    if(threadIdx.x ==0){
+        cumsum[0] = 0;
+        for(int i=1;i<=num_experts;++i){
+            cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
+        }
+        *total_tokens_post_pad = cumsum[num_experts];
+    }
+
+    __syncthreads();
+
+    for(int i= cumsum[threadIdx.x];i<cumsum[threadIdx.x + 1];i += block_size){
+        expert_ids[i / block_size] = threadIdx.x;
+    }
+
+    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+        int32_t expert_id = topk_ids[i];
+        int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
+        sorted_token_ids[rank_post_pad] = i;
+        ++tokens_cnts[threadIdx.x][expert_id];
+    }
+}
+
+void moe_alig_block_size(
+    torch::Tensor topk_ids,
+    int num_experts,
+    int block_size,
+    torch::Tensor sorted_token_ids,
+    torch::Tensor experts_ids,
+    torch::Tensor num_tokens_post_pad)
+{
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    assert(num_experts <= NUM_MAX_EXPERTS);
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
+        moe_alig_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
+            topk_ids.data_ptr<scalar_t>(),
+            sorted_token_ids.data_ptr<int32_t>(),
+            experts_ids.data_ptr<int32_t>(),
+            num_tokens_post_pad.data_ptr<int32_t>(),
+            num_experts,
+            block_size,
+            topk_ids.numel());
+    });
+}
\ No newline at end of file
diff --git a/awq_ext/vllm/moe_alig_block.h b/awq_ext/vllm/moe_alig_block.h
new file mode 100644
index 0000000..1d80552
--- /dev/null
+++ b/awq_ext/vllm/moe_alig_block.h
@@ -0,0 +1,8 @@
+void moe_alig_block_size(
+    torch::Tensor topk_ids,
+    int num_experts,
+    int block_size,
+    torch::Tensor sorted_token_ids,
+    torch::Tensor experts_ids,
+    torch::Tensor num_tokens_post_pad
+    );
\ No newline at end of file
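// What moe_alig_block_size produces (a sketch of my reading, not part of the diff):
// token slots are grouped by expert and each expert's group is padded up to a
// multiple of block_size, so every block handed to the grouped GEMM maps to a
// single expert:
//
//   counts[e]  = number of entries in topk_ids equal to e
//   padded[e]  = ceil(counts[e] / block_size) * block_size
//   cumsum[e]  = padded[0] + ... + padded[e - 1]
//   *num_tokens_post_pad = cumsum[num_experts]
//   expert_ids[b] = e  for every block b covering [cumsum[e], cumsum[e + 1])
//   sorted_token_ids[cumsum[e] + j] = flat index of the j-th entry routed to expert e
//   (padding slots keep their initial fill value, which the GEMM treats as invalid)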
diff --git a/awq_ext/vllm/topk_softmax_kernels.cu b/awq_ext/vllm/topk_softmax_kernels.cu
new file mode 100644
index 0000000..37cdb1c
--- /dev/null
+++ b/awq_ext/vllm/topk_softmax_kernels.cu
@@ -0,0 +1,493 @@
+/*
+ * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
+ * Copyright (c) 2024, The vLLM team.
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cub/cub.cuh>
+#include <cub/util_type.cuh>
+
+static constexpr int WARP_SIZE = 32;
+
+/// Aligned array type
+template <
+    typename T,
+    /// Number of elements in the array
+    int N,
+    /// Alignment requirement in bytes
+    int Alignment = sizeof(T) * N
+>
+class alignas(Alignment) AlignedArray {
+    float data[N];
+};
+
+// ====================== Softmax things ===============================
+// We have our own implementation of softmax here so we can support transposing the output
+// in the softmax kernel when we extend this module to support expert-choice routing.
+template <int TPB>
+__launch_bounds__(TPB) __global__
+    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
+{
+    using BlockReduce = cub::BlockReduce<float, TPB>;
+    __shared__ typename BlockReduce::TempStorage tmpStorage;
+
+    __shared__ float normalizing_factor;
+    __shared__ float float_max;
+
+    const int thread_row_offset = blockIdx.x * num_cols;
+
+    cub::Sum sum;
+    float threadData(-FLT_MAX);
+
+    // Don't touch finished rows.
+    if ((finished != nullptr) && finished[blockIdx.x])
+    {
+        return;
+    }
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        threadData = max(static_cast<float>(input[idx]), threadData);
+    }
+
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+    if (threadIdx.x == 0)
+    {
+        float_max = maxElem;
+    }
+    __syncthreads();
+
+    threadData = 0;
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        threadData += exp((static_cast<float>(input[idx]) - float_max));
+    }
+
+    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+
+    if (threadIdx.x == 0)
+    {
+        normalizing_factor = 1.f / Z;
+    }
+    __syncthreads();
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
+        output[idx] = val;
+    }
+}
+
+template <int TPB>
+__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
+    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+{
+
+    using cub_kvp = cub::KeyValuePair<int, float>;
+    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
+    __shared__ typename BlockReduce::TempStorage tmpStorage;
+
+    cub_kvp thread_kvp;
+    cub::ArgMax arg_max;
+
+    const int num_rows = gridDim.x;
+    const int block_row = blockIdx.x;
+
+    const bool row_is_active = finished ? !finished[block_row] : true;
+    const int thread_read_offset = blockIdx.x * num_experts;
+    for (int k_idx = 0; k_idx < k; ++k_idx)
+    {
+        thread_kvp.key = 0;
+        thread_kvp.value = -1.f; // This is OK because inputs are probabilities
+
+        cub_kvp inp_kvp;
+        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
+        {
+            const int idx = thread_read_offset + expert;
+            inp_kvp.key = expert;
+            inp_kvp.value = inputs_after_softmax[idx];
+
+            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
+            {
+                const int prior_winning_expert = indices[k * block_row + prior_k];
+
+                if (prior_winning_expert == expert)
+                {
+                    inp_kvp = thread_kvp;
+                }
+            }
+
+            thread_kvp = arg_max(inp_kvp, thread_kvp);
+        }
+
+        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
+        if (threadIdx.x == 0)
+        {
+            // Ignore experts the node isn't responsible for with expert parallelism
+            const int expert = result_kvp.key;
+            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
+            const bool should_process_row = row_is_active && node_uses_expert;
+
+            const int idx = k * block_row + k_idx;
+            output[idx] = result_kvp.value;
+            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
+            assert(indices[idx] >= 0);
+            source_rows[idx] = k_idx * num_rows + block_row;
+        }
+        __syncthreads();
+    }
+}
+
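// The warp-level "butterfly" reduction used repeatedly below, as a standalone sketch
// (illustration only, not part of the diff): after log2(width) exchanges every lane
// in a group of `width` threads holds the group maximum, with no shared memory.
//
//   __device__ __forceinline__ float group_max(float v, int width) {
//       for (int mask = width / 2; mask > 0; mask /= 2)
//           v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, mask, width));
//       return v;
//   }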
+// ====================== TopK softmax things ===============================
+
+/*
+  A Top-K gating softmax written to exploit when the number of experts in the MoE layers
+  is a small power of 2. This allows us to cleanly share the rows among the threads in
+  a single warp and eliminate communication between warps (so no need to use shared mem).
+
+  It fuses the softmax, max and argmax into a single kernel.
+
+  Limitations:
+  1) This implementation is intended for when the number of experts is a small power of 2.
+  2) This implementation assumes k is small, but will work for any k.
+*/
+
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+        int* source_rows, const int k, const int start_expert, const int end_expert)
+{
+    // We begin by enforcing compile time assertions and setting up compile time constants.
+    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
+    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
+    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
+    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
+
+    // Number of bytes each thread pulls in per load
+    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
+    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
+    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
+
+    // Restrictions based on previous section.
+    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
+    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
+    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+
+    // We have NUM_EXPERTS elements per row. We specialize for small #experts
+    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
+    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
+
+    // Restrictions for previous section.
+    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
+
+    // ===================== From this point, we finally start computing run-time variables. ========================
+
+    // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
+    // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
+    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
+
+    // Now, using the base row per thread block, we compute the base row per warp.
+    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
+
+    // The threads in a warp are split into sub-groups that will work on a row.
+    // We compute row offset for each thread sub-group
+    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
+    const int thread_row = warp_base_row + thread_row_in_warp;
+
+    // Threads with indices out of bounds should early exit here.
+    if (thread_row >= num_rows)
+    {
+        return;
+    }
+    const bool row_is_active = finished ? !finished[thread_row] : true;
+
+    // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
+    // row it will read.
+    const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
+
+    // Now, we compute the group each thread belongs to in order to determine the first column to start loads.
+    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
+    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
+    const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
+
+    // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
+    // this can support all powers of 2 up to 16.
+    // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
+    // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
+    using AccessType = AlignedArray<float, ELTS_PER_LDG>;
+
+    // Finally, we pull in the data from global mem
+    float row_chunk[VPT];
+    AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
+    const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
+#pragma unroll
+    for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
+    {
+        row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
+    }
+
+    // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
+    // convert to float afterwards for the exp + sum reduction.
+    float thread_max = row_chunk[0];
+#pragma unroll
+    for (int ii = 1; ii < VPT; ++ii)
+    {
+        thread_max = max(thread_max, row_chunk[ii]);
+    }
+
+// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
+#pragma unroll
+    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+    {
+        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+    }
+
+    // From this point, thread_max in every thread holds the max within the row.
+    // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
+    float row_sum = 0;
+#pragma unroll
+    for (int ii = 0; ii < VPT; ++ii)
+    {
+        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
+        row_sum += row_chunk[ii];
+    }
+
+// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
+#pragma unroll
+    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+    {
+        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
+    }
+
+    // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
+    // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
+    // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
+    // However, this kernel will likely not be a bottleneck and it seems better to closely match torch and find the
+    // argmax after computing the softmax.
+    const float reciprocal_row_sum = 1.f / row_sum;
+
+#pragma unroll
+    for (int ii = 0; ii < VPT; ++ii)
+    {
+        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
+    }
+
+    // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
+    // with the max index.
+    int start_col = first_elt_read_by_thread;
+    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
+
+    for (int k_idx = 0; k_idx < k; ++k_idx)
+    {
+        // First, each thread does the local argmax
+        float max_val = row_chunk[0];
+        int expert = start_col;
+#pragma unroll
+        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
+        {
+#pragma unroll
+            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
+            {
+                float val = row_chunk[ldg * ELTS_PER_LDG + ii];
+
+                // No check on the experts here since columns with the smallest index are processed first and only
+                // updated if > (not >=)
+                if (val > max_val)
+                {
+                    max_val = val;
+                    expert = col + ii;
+                }
+            }
+        }
+
+// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
+// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
+// then blank out their max with -inf and the warp can run more iterations...
+#pragma unroll
+        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+        {
+            float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
+            int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
+
+            // We want lower indices to "win" in every thread so we break ties this way
+            if (other_max > max_val || (other_max == max_val && other_expert < expert))
+            {
+                max_val = other_max;
+                expert = other_expert;
+            }
+        }
+
+        // Write the max for this k iteration to global memory.
+        if (thread_group_idx == 0)
+        {
+            // Add a guard to ignore experts not included by this node
+            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
+            const bool should_process_row = row_is_active && node_uses_expert;
+
+            // The lead thread from each sub-group will write out the final results to global memory. (This will be a
+            // single) thread per row of the input/output matrices.
+            const int idx = k * thread_row + k_idx;
+            output[idx] = max_val;
+            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
+            source_rows[idx] = k_idx * num_rows + thread_row;
+        }
+
+        // Finally, we clear the value in the thread with the current max if there is another iteration to run.
+        if (k_idx + 1 < k)
+        {
+            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
+            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
+
+            // Only the thread in the group which produced the max will reset the "winning" value to -inf.
+            if (thread_group_idx == thread_to_clear_in_group)
+            {
+                const int offset_for_expert = expert % ELTS_PER_LDG;
+                // Safe to set to any negative value since row_chunk values must be between 0 and 1.
+                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
+            }
+        }
+    }
+}
+
+namespace detail
+{
+// Constructs some constants needed to partition the work across threads at compile time.
+template <int EXPERTS, int BYTES_PER_LDG>
+struct TopkConstants
+{
+    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
+    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
+    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
+    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
+};
+} // namespace detail
+
+template <int EXPERTS, int WARPS_PER_TB>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
+{
+    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
+
+    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
+    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+    static constexpr int VPT = Constants::VPT;
+    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
+    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+    const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
+
+    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
+    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+        input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
+}
+
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                 \
+    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(   \
+        gating_output, nullptr, topk_weights, topk_indicies,      \
+        token_expert_indices, num_tokens, topk, 0, num_experts,   \
+        stream);
+
+void topkGatingSoftmaxKernelLauncher(
+    const float* gating_output,
+    float* topk_weights,
+    int* topk_indicies,
+    int* token_expert_indices,
+    float* softmax_workspace,
+    const int num_tokens,
+    const int num_experts,
+    const int topk,
+    cudaStream_t stream) {
+    static constexpr int WARPS_PER_TB = 4;
+    switch (num_experts) {
+    case 1:
+        LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+        break;
+    case 2:
+        LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+        break;
+    case 4:
+        LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+        break;
+    case 8:
+        LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+        break;
+    case 16:
+        LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+        break;
+    case 32:
+        LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+        break;
+    case 64:
+        LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+        break;
+    case 128:
+        LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+        break;
+    case 256:
+        LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+        break;
+    default: {
+        TORCH_CHECK(softmax_workspace != nullptr,
+            "softmax_workspace must be provided for num_experts that are not a power of 2.");
+        static constexpr int TPB = 256;
+        moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
+            gating_output, nullptr, softmax_workspace, num_experts);
+        moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
+            softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
+            num_experts, topk, 0, num_experts);
+    }
+    }
+}
+
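// Worked example (illustration only, not part of the diff): for num_experts = 8 the
// LAUNCH_SOFTMAX path instantiates topkGatingSoftmax with
//   BYTES_PER_LDG   = min(16, sizeof(float) * 8) = 16  ->  ELTS_PER_LDG = 4
//   VPT             = 4   (one 16-byte vector load covers a thread's whole chunk)
//   THREADS_PER_ROW = 8 / 4 = 2,  ROWS_PER_WARP = 32 / 2 = 16
// so with WARPS_PER_TB = 4 each thread block routes 64 rows entirely in registers,
// which is why the softmax_workspace tensor is only needed on the fallback path.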
+void topk_softmax(
+    torch::Tensor& topk_weights,           // [num_tokens, topk]
+    torch::Tensor& topk_indices,           // [num_tokens, topk]
+    torch::Tensor& token_expert_indices,   // [num_tokens, topk]
+    torch::Tensor& gating_output)          // [num_tokens, num_experts]
+{
+    const int num_experts = gating_output.size(-1);
+    const int num_tokens = gating_output.numel() / num_experts;
+    const int topk = topk_weights.size(-1);
+
+    const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
+    const bool needs_workspace = !is_pow_2 || num_experts > 256;
+    const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
+    topkGatingSoftmaxKernelLauncher(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        token_expert_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        stream);
+}
\ No newline at end of file
diff --git a/awq_ext/vllm/topk_softmax_kernels.h b/awq_ext/vllm/topk_softmax_kernels.h
new file mode 100644
index 0000000..7a7cbc6
--- /dev/null
+++ b/awq_ext/vllm/topk_softmax_kernels.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <torch/extension.h>
+
+void topk_softmax(
+    torch::Tensor& topk_weights,
+    torch::Tensor& topk_indices,
+    torch::Tensor& token_expert_indices,
+    torch::Tensor& gating_output);
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 461878f..9fbae06 100644
--- a/setup.py
+++ b/setup.py
@@ -172,6 +172,9 @@ def get_extra_link_args():
         "awq_ext/layernorm/layernorm.cu",
         "awq_ext/position_embedding/pos_encoding_kernels.cu",
         "awq_ext/quantization/gemv_cuda.cu",
+        "awq_ext/vllm/moe_alig_block.cu",
+        "awq_ext/vllm/activation.cu",
+        "awq_ext/vllm/topk_softmax_kernels.cu",
     ],
     extra_compile_args=extra_compile_args,
 )