diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 4ed690090144..04d2205aa372 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -8,12 +8,13 @@ import torch from transformers import AutoConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute, _moe_unpermute_and_reduce, + moe_permute, + moe_unpermute, ) -from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -63,18 +64,19 @@ def prepare(i: int): def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + ( + permuted_hidden_states, + a1q_scale, + first_token_off, + inv_perm_idx, + m_indices, + ) = moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, ) else: ( @@ -150,18 +152,19 @@ def benchmark_unpermute( def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + ( + permuted_hidden_states, + a1q_scale, + first_token_off, + inv_perm_idx, + m_indices, + ) = moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, ) # convert to fp16/bf16 as gemm output return ( @@ -191,16 +194,19 @@ def prepare(): def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input + ( + permuted_hidden_states, + first_token_off, + inv_perm_idx, + m_indices, + ) = input + output = torch.empty_like(hidden_states) moe_unpermute( + output, permuted_hidden_states, topk_weights, - topk_ids, inv_perm_idx, first_token_off, - topk, - num_experts, - num_experts, ) else: ( @@ -211,7 +217,11 @@ def run(input: tuple): inv_perm, ) = input _moe_unpermute_and_reduce( - output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + output_hidden_states, + permuted_hidden_states, + inv_perm, + topk_weights, + True, ) # JIT compilation & warmup diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index a77471a7f207..2922352a3f7c 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -10,32 +10,28 @@ void moe_permute( const torch::Tensor& input, // [n_token, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] const torch::Tensor& token_expert_indices, // [n_token, topk] const std::optional& expert_map, // [n_expert] int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, - torch::Tensor& - permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& permuted_input, // [permuted_size, hidden] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] - torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& inv_permuted_idx, // [n_token, topk] + torch::Tensor& permuted_idx, // [permute_size] torch::Tensor& m_indices) { // [align_expand_m] - TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, - "topk_weights must be float32"); TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, "topk_ids must be int32"); TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, "token_expert_indices must be int32"); - TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, - "src_row_id2dst_row_id_map must be int32"); + TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int, + "inv_permuted_idx must be int32"); TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, "expert_first_token_offset shape != n_local_expert+1") - TORCH_CHECK( - src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), - "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(), + "token_expert_indices shape must be same as inv_permuted_idx"); auto n_token = input.sizes()[0]; auto n_hidden = input.sizes()[1]; auto align_block_size_value = @@ -46,8 +42,9 @@ void moe_permute( auto sort_workspace = torch::empty( {sorter_size}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); - auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto sorted_row_idx = torch::empty_like(inv_permuted_idx); auto align_expert_first_token_offset = torch::zeros_like(expert_first_token_offset); @@ -67,24 +64,22 @@ void moe_permute( const int* expert_map_ptr = get_ptr(expert_map.value()); valid_num_ptr = get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + preprocessTopkIdLauncher(get_ptr(copy_topk_ids), n_token * topk, expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indices), - get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(expert_first_token_offset), n_token, - n_expert, n_local_expert, topk, sorter, - get_ptr(sort_workspace), stream); + sortAndScanExpert( + get_ptr(copy_topk_ids), get_ptr(token_expert_indices), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(expert_first_token_offset), n_token, n_expert, + n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); // dispatch expandInputRowsKernelLauncher MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), - get_ptr(topk_weights), get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(src_row_id2dst_row_id_map), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(inv_permuted_idx), get_ptr(permuted_idx), get_ptr(expert_first_token_offset), n_token, valid_num_ptr, n_hidden, topk, n_local_expert, align_block_size_value, stream); }); @@ -101,32 +96,34 @@ void moe_permute( } void moe_unpermute( - const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - const torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] - const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] - int64_t n_expert, int64_t n_local_expert, int64_t topk, + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, // [n_token, topk] + const torch::Tensor& inv_permuted_idx, // [n_token, topk] + const std::optional& + expert_first_token_offset, // [n_local_expert+1] + int64_t topk, torch::Tensor& hidden_states // [n_token, hidden] ) { - TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), - "topk_ids shape must be same as src_row_id2dst_row_id_map"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); TORCH_CHECK( permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), - "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + "permuted_hidden_states dtype must be same as hidden_states"); auto n_token = hidden_states.size(0); auto n_hidden = hidden_states.size(1); auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int64_t* valid_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; + + int64_t const* valid_ptr = nullptr; + if (expert_first_token_offset.has_value()) { + int n_local_expert = expert_first_token_offset.value().size(0) - 1; + valid_ptr = + get_ptr(expert_first_token_offset.value()) + n_local_expert; + } + MOE_DISPATCH(hidden_states.scalar_type(), [&] { finalizeMoeRoutingKernelLauncher( get_ptr(permuted_hidden_states), get_ptr(hidden_states), get_ptr(topk_weights), - get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), - n_token, n_hidden, topk, valid_ptr, stream); + get_ptr(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr, + stream); }); } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index de2c153882d9..2271c1bc75b1 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset, int tidx = threadIdx.x; extern __shared__ int64_t smem_expert_first_token_offset[]; for (int i = tidx; i <= num_local_expert; i += blockDim.x) { - smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i); } __syncthreads(); auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 43c29721cd16..108091efbefa 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream); -// Final kernel to unpermute and scale -// This kernel unpermutes the original data, does the k-way reduction and -// performs the final skip connection. -template -__global__ void finalizeMoeRoutingKernel( - T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, - float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr); - template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream); + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream); void preprocessTopkIdLauncher(int* topk_id_ptr, int size, const int* expert_map_ptr, int num_experts, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index ad0d390665a0..449243b92a28 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -2,10 +2,9 @@ template __global__ void expandInputRowsKernel( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k, int num_local_experts, int align_block_size) { @@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel( assert(expanded_dest_row <= INT32_MAX); expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast(expanded_dest_row); + // skip non local expert token + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + permuted_idx[expanded_dest_row] = expanded_source_row; + } } if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { @@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel( using DataElem = cutlass::Array; // Duplicate and permute rows - int64_t const source_row = expanded_source_row % num_rows; + int64_t const source_row = expanded_source_row / k; auto const* source_row_ptr = reinterpret_cast(unpermuted_input + source_row * cols); @@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel( template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream) { @@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher( int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); func<<>>( - unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, expert_first_token_offset, - num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, - align_block_size); + expanded_source_row_to_expanded_dest_row, permuted_idx, + expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, + num_local_experts, align_block_size); } template @@ -128,11 +130,9 @@ template __global__ void finalizeMoeRoutingKernel( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr) { + int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) { assert(orig_cols % 4 == 0); int64_t const original_row = blockIdx.x; - int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; int64_t const num_valid = *num_valid_ptr; @@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel( ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < k; ++k_idx) { - int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_original_row = original_row * k + k_idx; int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; int64_t const k_offset = original_row * k + k_idx; float const row_scale = scales[k_offset]; - // Check after row_rescale has accumulated if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { continue; } @@ -189,9 +188,8 @@ template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream) { + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream) { int64_t const blocks = num_rows; int64_t const threads = 256; bool const check_finished = num_valid_ptr != nullptr; @@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher( auto* const kernel = func_map[check_finished]; kernel<<>>( expanded_permuted_rows, reduced_unpermuted_output, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, - num_valid_ptr); + expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 97df311d0440..d96e082f6ef1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " -> Tensor"); m.def( - "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," + "moe_permute(Tensor input, Tensor topk_ids," "Tensor token_expert_indices, Tensor? expert_map, int n_expert," "int n_local_expert," "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " - "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " - "m_indices)->()"); + "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " + "permuted_idx, Tensor! m_indices)->()"); m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," - "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " - "expert_first_token_offset, int n_expert, int n_local_expert,int " - "topk, Tensor! hidden_states)->()"); + "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, " + "int topk, Tensor! hidden_states)->()"); m.def("moe_permute_unpermute_supported() -> bool"); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 7cc83b512c8b..8d215a0cbeed 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -17,28 +17,34 @@ moe_permute, moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform -NUM_EXPERTS = [16, 64] +NUM_EXPERTS = [16, 64, 256] TOP_KS = [2, 4, 6, 8] EP_SIZE = [1, 4, 16] current_platform.seed_everything(0) -def torch_permute(hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: +def torch_permute( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + # token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: is_local_expert = (expert_map[topk_ids] != -1) not_local_expert = (expert_map[topk_ids] == -1) topk_ids = is_local_expert * ( topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) @@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor, valid_row_idx = [] if align_block_size is None: - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % - n_token, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // + topk, ...] permuted_row_size = permuted_hidden_states.shape[0] m_indices = torch.empty(permuted_row_size, device="cuda", @@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor, 0, n_token * topk, device="cuda", dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + dst_row_id2src_row_id_map[ + expert_first_token_offset[-1]:] = n_token * topk return [ permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices, valid_row_idx + src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices, + valid_row_idx ] else: permuted_row_size = (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size + permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype) @@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor, align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + - n_token_in_expert] % n_token + first_token_offset:first_token_offset + n_token_in_expert] # store token in current expert with align_first_token_offset permuted_hidden_states[align_first_token_offset:\ align_first_token_offset+n_token_in_expert,\ ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert, ...] + dst_row_id2src_row_id_in_expert // topk,\ + ...] + permuted_idx[align_first_token_offset:\ + align_first_token_offset+\ + n_token_in_expert] = dst_row_id2src_row_id_in_expert # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ @@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor, src2dst_idx].reshape((n_token, topk)) return [ permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, m_indices, valid_row_idx + align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx ] @@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, valid_row_idx: torch.Tensor, topk: int, n_expert: int) -> torch.Tensor: # ignore invalid row + n_hidden = permuted_hidden_states.shape[1] mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 - idx = src_row_id2dst_row_id_map.flatten()[ - token_expert_indices.flatten()].reshape(token_expert_indices.shape) - output = permuted_hidden_states[idx, ...] * topk_weights[..., None] - output = output.sum(dim=1).to(permuted_hidden_states.dtype) + + permuted_hidden_states = permuted_hidden_states[ + src_row_id2dst_row_id_map.flatten(), ...] + permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden) + output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to( + permuted_hidden_states.dtype) return output @@ -184,43 +203,56 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states, gating_output, topk, False) - gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( - hidden_states, - topk_ids, - token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) - - result0, result1, result2, result3 = moe_permute( - hidden_states, topk_weights, topk_ids, token_expert_indices, topk, - n_expert, n_local_expert, expert_map, align_block_size, - fill_invalid_expert) + (gold_permuted_hidden_states, gold_expert_first_token_offset, + gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices, + valid_row_idx) = torch_permute( + hidden_states, + topk_ids, + # token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx, + m_indices) = moe_permute(hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + n_local_expert=n_local_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) # check expert_first_token_offset - torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + torch.testing.assert_close(gold_expert_first_token_offset, + expert_first_token_offset, + atol=0, + rtol=0) # check src_row_id2dst_row_id_map - torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + torch.testing.assert_close(gold_inv_permuted_idx.flatten(), + inv_permuted_idx, + atol=0, + rtol=0) # check mindice - torch.testing.assert_close(gold3, result3, atol=0, rtol=0) + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold0[valid_row_idx], - result0[valid_row_idx], + torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], + permuted_hidden_states[valid_row_idx], atol=0, rtol=0) - # add a random tensor to simulate group gemm - result0 = 0.5 * result0 + torch.randn_like(result0) + result0 = 0.5 * permuted_hidden_states + torch.randn_like( + permuted_hidden_states) + result4 = torch.empty_like(hidden_states) + moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, + expert_first_token_offset) - result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, - topk, n_expert, n_local_expert) gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, result2, valid_row_idx, topk, - n_local_expert) - + token_expert_indices, inv_permuted_idx, + valid_row_idx, topk, n_local_expert) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 20ee0d9f780a..d9059f50b445 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -76,43 +76,43 @@ def _moe_unpermute_and_reduce( def moe_permute( hidden_states: torch.Tensor, - topk_weights: torch.Tensor, + a1q_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, n_expert: int, - n_local_expert: int, + n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - topk_weights (torch.Tensor): topk expert route weight for each token. + - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states - topk_ids (torch.Tensor): topk expert route id for each token. - - token_expert_indices (torch.Tensor): indice for expanded hidden. - - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices Returns: - permuted_hidden_states (torch.Tensor): permuted activation. + - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. - - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute. + - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden. - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records the group which the j-th row of the LHS belong to.` """ n_token, n_hidden = hidden_states.size() + topk = topk_ids.size(1) assert (n_hidden * hidden_states.element_size() ) % 16 == 0, "permue kernel need hidden dim align to 16B" permuted_row_size = n_token * topk @@ -120,12 +120,19 @@ def moe_permute( permuted_row_size = (permuted_row_size + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size - + if n_local_expert == -1: + n_local_expert = n_expert permuted_hidden_states = torch.empty( (permuted_row_size, n_hidden), dtype=hidden_states.dtype, device=hidden_states.device, ) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) + m_indices = torch.full((permuted_row_size, ), fill_invalid_expert, dtype=torch.int32, @@ -133,57 +140,54 @@ def moe_permute( expert_first_token_offset = torch.empty(n_local_expert + 1, dtype=torch.int64, device=hidden_states.device) - src_row_id2dst_row_id_map = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) - torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, - token_expert_indices, expert_map, n_expert, - n_local_expert, topk, align_block_size, - permuted_hidden_states, - expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) - return (permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) + permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) + inv_permuted_idx = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + topk_ids = topk_ids.to(torch.int32) + torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices, + expert_map, n_expert, n_local_expert, topk, + align_block_size, permuted_hidden_states, + expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) + if a1q_scale is not None: + a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // + topk] + return (permuted_hidden_states, a1q_scale, expert_first_token_offset, + inv_permuted_idx.flatten(), m_indices) def moe_unpermute( + out: torch.Tensor, permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - expert_first_token_offset: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, -) -> torch.Tensor: + inv_permuted_idx: torch.Tensor, + expert_first_token_offset: Optional[torch.Tensor] = None, +) -> None: """ This function expands and permutes activation to gathering uncontinuous tokens for each expert. Parameters: + - out (torch.Tensor): output tensor - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - - topk_ids (torch.Tensor): topk expert route id for each token. - - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for grouped gemm. - - topk (int): The number of top-k experts to select. - - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. + - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. + - expert_first_token_offset (Optional[torch.Tensor]): offset of the first + token of each expert for grouped gemm. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation tensor. """ - n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) + topk = topk_weights.size(1) + n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" - hidden_states = torch.empty((n_token, n_hidden), - dtype=permuted_hidden_states.dtype, - device=permuted_hidden_states.device) - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - topk_ids, src_row_id2dst_row_id_map, - expert_first_token_offset, n_expert, - n_local_expert, topk, hidden_states) - return hidden_states + inv_permuted_idx, expert_first_token_offset, + topk, out) def moe_permute_unpermute_supported():