@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
5959 int64_t topk,
6060 int64_t num_fused_shared_experts,
6161 double routed_scaling_factor,
62- bool apply_routed_scaling_factor_on_output,
6362 Params params) {
6463 int tidx = threadIdx .x ;
6564 int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
249248 for (int ii = 0 ; ii < topk; ++ii) {
250249 int64_t const idx = topk * thread_row + ii;
251250 output_ptr[idx] = output_ptr[idx] / output_sum;
252- if (apply_routed_scaling_factor_on_output) {
253- output_ptr[idx] *= routed_scaling_factor;
254- }
255251 }
256252 }
257253}
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
286282 int64_t topk_group,
287283 int64_t topk,
288284 int64_t num_fused_shared_experts,
289- double routed_scaling_factor,
290- bool apply_routed_scaling_factor_on_output) {
285+ double routed_scaling_factor) {
291286 KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
292287 moe_fused_gate_impl<T>(
293288 input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
299294 topk,
300295 num_fused_shared_experts,
301296 routed_scaling_factor,
302- apply_routed_scaling_factor_on_output,
303297 params);
304298}
305299
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
320314 topk_group, \
321315 topk, \
322316 num_fused_shared_experts, \
323- routed_scaling_factor, \
324- apply_routed_scaling_factor_on_output); \
317+ routed_scaling_factor); \
325318 dispatched = true ; \
326319 } while (0 )
327320
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
349342 int64_t topk_group,
350343 int64_t topk,
351344 int64_t num_fused_shared_experts,
352- double routed_scaling_factor,
353- bool apply_routed_scaling_factor_on_output) {
345+ double routed_scaling_factor) {
354346 KernelParamsDynamic params;
355347 params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
356348 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
369361 topk,
370362 num_fused_shared_experts,
371363 routed_scaling_factor,
372- apply_routed_scaling_factor_on_output,
373364 params);
374365}
375366
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
383374 int64_t topk_group,
384375 int64_t topk,
385376 int64_t num_fused_shared_experts,
386- double routed_scaling_factor,
387- bool apply_routed_scaling_factor_on_output) {
377+ double routed_scaling_factor) {
388378 int64_t num_rows = input.size (0 );
389379 int32_t num_experts = input.size (1 );
390380 auto options = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCUDA );
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
483473 topk_group,
484474 topk,
485475 num_fused_shared_experts,
486- routed_scaling_factor,
487- apply_routed_scaling_factor_on_output);
476+ routed_scaling_factor);
488477 } else if (input.scalar_type () == at::kHalf ) {
489478 moe_fused_gate_kernel_dynamic<float16_t ><<<num_blocks, block_dim, 0 , stream>>> (
490479 input.data_ptr (),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
497486 topk_group,
498487 topk,
499488 num_fused_shared_experts,
500- routed_scaling_factor,
501- apply_routed_scaling_factor_on_output);
489+ routed_scaling_factor);
502490 } else if (input.scalar_type () == at::kFloat ) {
503491 moe_fused_gate_kernel_dynamic<float32_t ><<<num_blocks, block_dim, 0 , stream>>> (
504492 input.data_ptr (),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
511499 topk_group,
512500 topk,
513501 num_fused_shared_experts,
514- routed_scaling_factor,
515- apply_routed_scaling_factor_on_output);
502+ routed_scaling_factor);
516503 } else {
517504 TORCH_CHECK (false , " Unsupported data type for moe_fused_gate" );
518505 }
0 commit comments