diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index 42441800fb11..ad0d390665a0 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -20,7 +20,6 @@ __global__ void expandInputRowsKernel( int expert_id = sorted_experts[expanded_dest_row]; extern __shared__ int64_t smem_expert_first_token_offset[]; - int64_t align_expanded_row_accumulate = 0; if constexpr (ALIGN_BLOCK_SIZE) { // load g2s for (int idx = threadIdx.x; idx < num_local_experts + 1; @@ -63,7 +62,6 @@ __global__ void expandInputRowsKernel( using DataElem = cutlass::Array; // Duplicate and permute rows - int64_t const source_k_rank = expanded_source_row / num_rows; int64_t const source_row = expanded_source_row % num_rows; auto const* source_row_ptr = @@ -160,7 +158,6 @@ __global__ void finalizeMoeRoutingKernel( elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); - float row_rescale{0.f}; for (int k_idx = 0; k_idx < k; ++k_idx) { int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = @@ -177,8 +174,6 @@ __global__ void finalizeMoeRoutingKernel( auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - int64_t const expert_idx = expert_for_source_row[k_offset]; - ComputeElem expert_result = arrayConvert( expanded_permuted_rows_row_ptr[elem_index]); thread_output = thread_output + row_scale * (expert_result);