Skip to content
This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit fa58402

Browse files
MoE grouped gemm and fused topk_softmax (#8)
* Initial * group gemm * Fix install. Add topk_softmax kernels.
1 parent c448678 commit fa58402

File tree

10 files changed

+1013
-26
lines changed

10 files changed

+1013
-26
lines changed

awq_ext/pybind_awq.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,20 @@
44
#include "quantization/gemm_cuda.h"
55
#include "quantization/gemv_cuda.h"
66
#include "position_embedding/pos_encoding.h"
7+
#include "vllm/moe_alig_block.h"
8+
#include "vllm/activation.h"
9+
#include "vllm/topk_softmax_kernels.h"
710

811
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
912
{
1013
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
1114
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
15+
m.def("grouped_gemm_forward", &grouped_gemm_forward, "Quantized grouped GEMM kernel.");
1216
m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
1317
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
1418
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
1519
m.def("dequantize_weights_cuda", &dequantize_weights_cuda, "Dequantize weights.");
20+
m.def("moe_alig_block_size", &moe_alig_block_size, "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
21+
m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
22+
m.def("topk_softmax", &topk_softmax, "Computes fused topk and softmax operation.");
1623
}

awq_ext/quantization/gemm_cuda.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33
torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
44
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
55

6+
torch::Tensor grouped_gemm_forward(
7+
torch::Tensor _in_feats,
8+
torch::Tensor _kernel,
9+
torch::Tensor _scaling_factors,
10+
torch::Tensor _zeros,
11+
torch::Tensor _topk_weights,
12+
torch::Tensor _sorted_token_ids_ptr,
13+
torch::Tensor _expert_ids_ptr,
14+
torch::Tensor _num_tokens_post_padded,
15+
bool mul_weights,
16+
int split_k_iters);
17+
618
torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
719
torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
820

awq_ext/quantization/gemm_cuda_gen.cu

Lines changed: 331 additions & 26 deletions
Large diffs are not rendered by default.

awq_ext/vllm/activation.cu

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include <ATen/cuda/CUDAContext.h>
2+
#include <torch/extension.h>
3+
#include <c10/cuda/CUDAGuard.h>
4+
5+
#define VLLM_LDG(arg) *(arg)
6+
7+
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
8+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
9+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
10+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
11+
12+
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
13+
AT_DISPATCH_SWITCH( \
14+
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
15+
16+
template<typename T>
17+
__device__ __forceinline__ T silu(const T& x) {
18+
// x * sigmoid(x)
19+
return (T) (((float) x) / (1.0f + expf((float) -x)));
20+
}
21+
22+
template<typename scalar_t>
23+
__global__ void silu_and_mul_kernel(
24+
scalar_t* __restrict__ out, // [..., d]
25+
const scalar_t* __restrict__ input, // [..., 2, d]
26+
const int d) {
27+
const int64_t token_idx = blockIdx.x;
28+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
29+
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
30+
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
31+
out[token_idx * d + idx] = silu(x) * y;
32+
}
33+
}
34+
35+
36+
void silu_and_mul(
37+
torch::Tensor& out, // [..., d]
38+
torch::Tensor& input) // [..., 2 * d]
39+
{
40+
int64_t num_tokens = input.numel() / input.size(-1);
41+
int d = input.size(-1) / 2;
42+
43+
dim3 grid(num_tokens);
44+
dim3 block(std::min(d, 1024));
45+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
46+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
47+
VLLM_DISPATCH_FLOATING_TYPES(
48+
input.scalar_type(),
49+
"silu_and_mul_kernel",
50+
[&] {
51+
silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
52+
out.data_ptr<scalar_t>(),
53+
input.data_ptr<scalar_t>(),
54+
d);
55+
});
56+
}

awq_ext/vllm/activation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
void silu_and_mul(
2+
torch::Tensor& out,
3+
torch::Tensor& input);

awq_ext/vllm/moe_alig_block.cu

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include <torch/extension.h>
2+
#include <ATen/cuda/CUDAContext.h>
3+
4+
#include <ATen/ATen.h>
5+
#include <THC/THCAtomics.cuh>
6+
7+
const static size_t NUM_MAX_EXPERTS = 64;
8+
9+
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
10+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
11+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
12+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
13+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
14+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
15+
16+
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
17+
AT_DISPATCH_SWITCH( \
18+
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
19+
20+
template <typename scalar_t>
21+
__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
22+
int32_t *sorted_token_ids,
23+
int32_t *expert_ids,
24+
int32_t *total_tokens_post_pad,
25+
int32_t num_experts,
26+
int32_t block_size,
27+
size_t numel) {
28+
const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x);
29+
const size_t start_idx = threadIdx.x * tokens_per_thread;
30+
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
31+
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
32+
for(int i = 0;i < num_experts;i++){
33+
tokens_cnts[threadIdx.x + 1][i] = 0;
34+
}
35+
36+
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
37+
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
38+
}
39+
40+
__syncthreads();
41+
42+
tokens_cnts[0][threadIdx.x] = 0;
43+
for(int i=1;i<=blockDim.x;++i){
44+
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
45+
}
46+
47+
__syncthreads();
48+
49+
if(threadIdx.x ==0){
50+
cumsum[0] = 0;
51+
for(int i=1;i<=num_experts;++i){
52+
cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
53+
}
54+
*total_tokens_post_pad = cumsum[num_experts];
55+
}
56+
57+
__syncthreads();
58+
59+
for(int i= cumsum[threadIdx.x];i<cumsum[threadIdx.x + 1];i += block_size){
60+
expert_ids[i / block_size] = threadIdx.x;
61+
}
62+
63+
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
64+
int32_t expert_id = topk_ids[i];
65+
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
66+
sorted_token_ids[rank_post_pad] = i;
67+
++tokens_cnts[threadIdx.x][expert_id];
68+
}
69+
}
70+
71+
void moe_alig_block_size(
72+
torch::Tensor topk_ids,
73+
int num_experts,
74+
int block_size,
75+
torch::Tensor sorted_token_ids,
76+
torch::Tensor experts_ids,
77+
torch::Tensor num_tokens_post_pad) {
78+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
79+
assert(num_experts <= NUM_MAX_EXPERTS);
80+
VLLM_DISPATCH_INTEGRAL_TYPES(
81+
topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
82+
moe_alig_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
83+
topk_ids.data_ptr<scalar_t>(),
84+
sorted_token_ids.data_ptr<int32_t>(),
85+
experts_ids.data_ptr<int32_t>(),
86+
num_tokens_post_pad.data_ptr<int32_t>(),
87+
num_experts,
88+
block_size,
89+
topk_ids.numel());
90+
});
91+
}

awq_ext/vllm/moe_alig_block.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
void moe_alig_block_size(
2+
torch::Tensor topk_ids,
3+
int num_experts,
4+
int block_size,
5+
torch::Tensor sorted_token_ids,
6+
torch::Tensor experts_ids,
7+
torch::Tensor num_tokens_post_pad
8+
);

0 commit comments

Comments
 (0)