@@ -28,8 +28,8 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad, int32_t* num_tokens_per_lora,
-    int32_t* adapter_enabled, int32_t* lora_ids) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;

@@ -131,8 +131,8 @@ void moe_lora_align_block_size(
     int64_t num_experts, int64_t block_size, int64_t max_loras,
     int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
     torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
-    torch::Tensor num_tokens_post_pad, torch::Tensor num_tokens_per_lora,
-    torch::Tensor adapter_enabled, torch::Tensor lora_ids) {
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);

   TORCH_CHECK(block_size > 0, "block_size should be greater than 0.");
@@ -169,7 +169,6 @@ void moe_lora_align_block_size(
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
         num_tokens_post_pad.data_ptr<int32_t>(),
-        num_tokens_per_lora.data_ptr<int32_t>(),
         adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
   });
 }
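
For context, below is a minimal host-side sketch of how a caller might allocate the output tensors and invoke the op after this change. Only the parameters that appear in the diff above are confirmed; the leading topk_ids argument, the worst-case sizing formulas, and the per-slot output shapes are assumptions for illustration, not taken from this commit.

// Host-side sketch (libtorch). Trailing parameters match the diff; anything
// before `num_experts` in the declaration is hypothetical and may differ in
// the real source.
#include <torch/torch.h>

void moe_lora_align_block_size(
    torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
    int64_t max_loras, int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids);

void call_moe_lora_align(torch::Tensor topk_ids,  // [num_tokens, topk], int
                         int64_t num_experts, int64_t block_size,
                         int64_t max_loras) {
  const int64_t num_tokens = topk_ids.size(0);
  const int64_t topk_num = topk_ids.size(1);

  // Assumed worst-case bound: every (token, expert) pair forces its own
  // padded block, mirroring the usual moe_align_block_size sizing.
  const int64_t max_num_tokens_padded =
      num_tokens * topk_num + num_experts * (block_size - 1);
  const int64_t max_num_m_blocks =
      (max_num_tokens_padded + block_size - 1) / block_size;

  auto opts = torch::dtype(torch::kInt32).device(topk_ids.device());
  auto sorted_token_ids =
      torch::empty({max_loras * max_num_tokens_padded}, opts);
  auto expert_ids = torch::empty({max_loras * max_num_m_blocks}, opts);
  auto num_tokens_post_pad = torch::zeros({max_loras}, opts);
  auto adapter_enabled = torch::zeros({max_loras}, opts);  // 0/1 flag per slot
  auto lora_ids = torch::empty({max_loras}, opts);

  // Updated argument order: num_tokens_per_lora is no longer passed.
  moe_lora_align_block_size(topk_ids, num_experts, block_size, max_loras,
                            max_num_tokens_padded, max_num_m_blocks,
                            sorted_token_ids, expert_ids, num_tokens_post_pad,
                            adapter_enabled, lora_ids);
}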