@@ -389,19 +389,19 @@ def add_lora_fused_moe(
389389 top_k_num ,
390390 lora_ids ,
391391 adapter_enabled ,
392- shrink_config .get ("BLOCK_SIZE_M" , 64 ) ,
393- shrink_config .get ("BLOCK_SIZE_N" , 64 ) ,
394- shrink_config .get ("BLOCK_SIZE_K" , 32 ) ,
395- shrink_config .get ("GROUP_SIZE_M" , 8 ) ,
396- shrink_config .get ("NUM_WARPS" , 4 ) ,
397- shrink_config .get ("NUM_STAGES" , 3 ) ,
398- shrink_config .get ("SPLIT_K" , 1 ) ,
399- expand_config .get ("BLOCK_SIZE_M" , 64 ) ,
400- expand_config .get ("BLOCK_SIZE_N" , 64 ) ,
401- expand_config .get ("BLOCK_SIZE_K" , 64 ) ,
402- expand_config .get ("GROUP_SIZE_M" , 64 ) ,
403- expand_config .get ("NUM_WARPS" , 4 ) ,
404- expand_config .get ("NUM_STAGES" , 3 ) ,
405- expand_config .get ("SPLIT_K" , 1 ) ,
392+ shrink_config .get ("BLOCK_SIZE_M" ) or shrink_config . get ( "block_m" ) or 64 ,
393+ shrink_config .get ("BLOCK_SIZE_N" ) or shrink_config . get ( "block_n" ) or 64 ,
394+ shrink_config .get ("BLOCK_SIZE_K" ) or shrink_config . get ( "block_k" ) or 32 ,
395+ shrink_config .get ("GROUP_SIZE_M" ) or shrink_config . get ( "group_m" ) or 8 ,
396+ shrink_config .get ("NUM_WARPS" ) or shrink_config . get ( "num_warps" ) or 4 ,
397+ shrink_config .get ("NUM_STAGES" ) or shrink_config . get ( "num_stages" ) or 3 ,
398+ shrink_config .get ("SPLIT_K" ) or shrink_config . get ( "split_k" ) or 1 ,
399+ expand_config .get ("BLOCK_SIZE_M" ) or expand_config . get ( "block_m" ) or 64 ,
400+ expand_config .get ("BLOCK_SIZE_N" ) or expand_config . get ( "block_n" ) or 64 ,
401+ expand_config .get ("BLOCK_SIZE_K" ) or expand_config . get ( "block_k" ) or 64 ,
402+ expand_config .get ("GROUP_SIZE_M" ) or expand_config . get ( "group_m" ) or 64 ,
403+ expand_config .get ("NUM_WARPS" ) or expand_config . get ( "num_warps" ) or 4 ,
404+ expand_config .get ("NUM_STAGES" ) or expand_config . get ( "num_stages" ) or 3 ,
405+ expand_config .get ("SPLIT_K" ) or expand_config . get ( "split_k" ) or 1 ,
406406 mul_routed_weight ,
407407 )
0 commit comments